diff --git a/sdk/cosmos/azure_data_cosmos/CHANGELOG.md b/sdk/cosmos/azure_data_cosmos/CHANGELOG.md index ce1ae257d1d..c3e475ed767 100644 --- a/sdk/cosmos/azure_data_cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure_data_cosmos/CHANGELOG.md @@ -8,6 +8,7 @@ - Removed the `request_url()` accessor (gated on the `fault_injection` feature) from `ItemResponse`/`ResourceResponse`/`BatchResponse`. Driver-routed operations never populated it, so it always returned `None` in current usage. - `CosmosClientBuilder::with_user_agent_suffix` (and `CosmosClientOptions::with_user_agent_suffix`) now take `UserAgentSuffix` instead of `impl Into`. Callers passing a `&str` or `String` must construct the value explicitly via `UserAgentSuffix::new` (panics on invalid input) or `UserAgentSuffix::try_new` (returns `Option`). Validation rules (max 25 characters, HTTP-header-safe) are now enforced at the construction site instead of being applied silently inside the builder. ([#4368](https://github.com/Azure/azure-sdk-for-rust/pull/4368)) +- `ContainerClient::query_items()` now takes a `QueryScope` (`QueryScope::partition(...)`, `QueryScope::feed_range(...)`, or `QueryScope::full_container()`) instead of a partition key where `()` represented cross-partition queries. - Replaced `CosmosDiagnostics` with `CosmosDiagnosticsContext` (a re-export of `azure_data_cosmos_driver::diagnostics::DiagnosticsContext`). All response types now return `Arc` from `diagnostics()` (the returned `Arc` derefs transparently to `CosmosDiagnosticsContext` for read-only inspection, and can be retained alongside a consumed response body). The previous `activity_id() -> Option<&str>` and `server_duration_ms() -> Option` accessors on `CosmosDiagnostics` are replaced by `CosmosDiagnosticsContext::activity_id() -> &ActivityId` and per-request server timing via `CosmosDiagnosticsContext::requests()[i].server_duration_ms()`. diff --git a/sdk/cosmos/azure_data_cosmos/examples/cosmos/query.rs b/sdk/cosmos/azure_data_cosmos/examples/cosmos/query.rs index 78acae5c1a2..786460a3191 100644 --- a/sdk/cosmos/azure_data_cosmos/examples/cosmos/query.rs +++ b/sdk/cosmos/azure_data_cosmos/examples/cosmos/query.rs @@ -3,7 +3,7 @@ use std::error::Error; -use azure_data_cosmos::{CosmosClient, PartitionKey}; +use azure_data_cosmos::{query::QueryScope, CosmosClient}; use clap::{Args, Subcommand}; use futures::TryStreamExt; @@ -55,13 +55,14 @@ impl QueryCommand { let db_client = client.database_client(&database); let container_client = db_client.container_client(&container).await?; - let pk = match partition_key { - Some(pk) => PartitionKey::from(pk), - None => PartitionKey::EMPTY, + let scope = match partition_key { + Some(pk) => QueryScope::partition(pk), + None => QueryScope::full_container(), }; - let mut items = - container_client.query_items::(&query, pk, None)?; + let mut items = container_client + .query_items::(&query, scope, None) + .await?; println!("Items:"); while let Some(item) = items.try_next().await? { @@ -70,7 +71,7 @@ impl QueryCommand { Ok(()) } Subcommands::Databases { query } => { - let mut dbs = client.query_databases(query, None)?; + let mut dbs = client.query_databases(query, None).await?; println!("Databases:"); while let Some(item) = dbs.try_next().await? { @@ -80,7 +81,7 @@ impl QueryCommand { } Subcommands::Containers { database, query } => { let db_client = client.database_client(&database); - let mut dbs = db_client.query_containers(query, None)?; + let mut dbs = db_client.query_containers(query, None).await?; println!("Containers:"); while let Some(item) = dbs.try_next().await? { diff --git a/sdk/cosmos/azure_data_cosmos/src/clients/container_client.rs b/sdk/cosmos/azure_data_cosmos/src/clients/container_client.rs index 6fef84ecc8c..4a0d0e06073 100644 --- a/sdk/cosmos/azure_data_cosmos/src/clients/container_client.rs +++ b/sdk/cosmos/azure_data_cosmos/src/clients/container_client.rs @@ -3,7 +3,6 @@ use crate::{ clients::{offers_client, ClientContext}, - feed_range::FeedRange, models::{ BatchResponse, ContainerProperties, ItemResponse, ResourceResponse, ThroughputProperties, }, @@ -11,9 +10,10 @@ use crate::{ BatchOptions, Precondition, QueryOptions, ReadContainerOptions, ReadFeedRangesOptions, SessionToken, }, + query::QueryScope, transactional_batch::TransactionalBatch, - DeleteContainerOptions, FeedItemIterator, ItemReadOptions, ItemWriteOptions, PartitionKey, - Query, ReplaceContainerOptions, ThroughputOptions, + DeleteContainerOptions, FeedItemIterator, FeedRange, ItemReadOptions, ItemWriteOptions, + PartitionKey, Query, ReplaceContainerOptions, ThroughputOptions, }; use super::ThroughputPoller; @@ -84,7 +84,7 @@ impl ContainerClient { let driver_response = self .context .driver - .execute_operation(operation, OperationOptions::default()) + .execute_point_operation(operation, OperationOptions::default()) .await?; Ok(ResourceResponse::new( @@ -138,7 +138,7 @@ impl ContainerClient { let driver_response = self .context .driver - .execute_operation(operation, operation_options) + .execute_point_operation(operation, operation_options) .await?; Ok(ResourceResponse::new( @@ -230,7 +230,7 @@ impl ContainerClient { let driver_response = self .context .driver - .execute_operation(operation, OperationOptions::default()) + .execute_point_operation(operation, OperationOptions::default()) .await?; Ok(ResourceResponse::new( @@ -316,7 +316,7 @@ impl ContainerClient { // Build the driver's item reference from our stored container metadata. let item_ref = ItemReference::from_name( &self.container_ref, - partition_key.into().into_driver_partition_key(), + partition_key.into(), item_id.to_owned(), ); @@ -328,7 +328,7 @@ impl ContainerClient { let driver_response = self .context .driver - .execute_operation(operation, options.operation) + .execute_point_operation(operation, options.operation) .await?; // Bridge the driver response to the SDK response type. @@ -414,7 +414,7 @@ impl ContainerClient { // Build the driver's item reference from our stored container metadata. let item_ref = ItemReference::from_name( &self.container_ref, - partition_key.into().into_driver_partition_key(), + partition_key.into(), item_id.to_owned(), ); @@ -426,7 +426,7 @@ impl ContainerClient { let driver_response = self .context .driver - .execute_operation(operation, options.operation) + .execute_point_operation(operation, options.operation) .await?; // Bridge the driver response to the SDK response type. @@ -516,7 +516,7 @@ impl ContainerClient { // Build the driver's item reference from our stored container metadata. let item_ref = ItemReference::from_name( &self.container_ref, - partition_key.into().into_driver_partition_key(), + partition_key.into(), item_id.to_owned(), ); @@ -528,7 +528,7 @@ impl ContainerClient { let driver_response = self .context .driver - .execute_operation(operation, options.operation) + .execute_point_operation(operation, options.operation) .await?; // Bridge the driver response to the SDK response type. @@ -576,7 +576,7 @@ impl ContainerClient { // Build the driver's item reference from our stored container metadata. let item_ref = ItemReference::from_name( &self.container_ref, - partition_key.into().into_driver_partition_key(), + partition_key.into(), item_id.to_owned(), ); @@ -588,7 +588,7 @@ impl ContainerClient { let driver_response = self .context .driver - .execute_operation(operation, options.operation) + .execute_point_operation(operation, options.operation) .await?; // Bridge the driver response to the SDK response type. @@ -628,7 +628,7 @@ impl ContainerClient { // Build the driver's item reference from our stored container metadata. let item_ref = ItemReference::from_name( &self.container_ref, - partition_key.into().into_driver_partition_key(), + partition_key.into(), item_id.to_owned(), ); @@ -640,7 +640,7 @@ impl ContainerClient { let driver_response = self .context .driver - .execute_operation(operation, options.operation) + .execute_point_operation(operation, options.operation) .await?; // Bridge the driver response to the SDK response type. @@ -661,7 +661,7 @@ impl ContainerClient { /// # Arguments /// /// * `query` - The query to execute. - /// * `partition_key` - The partition key to scope the query on, or specify an empty key (`()`) to perform a cross-partition query. + /// * `scope` - The [`QueryScope`] specifying the scope of the query. /// * `options` - Optional parameters for the request. /// /// # Cross Partition Queries @@ -672,11 +672,12 @@ impl ContainerClient { /// /// # Examples /// - /// The `query` and `partition_key` parameters accept anything that can be transformed [`Into`] their relevant types. + /// The `query` parameter accepts anything that can be transformed [`Into`] a [`Query`], and `scope` controls partition targeting. /// This allows simple queries without parameters to be expressed easily: /// /// ```rust,no_run /// # async fn doc() -> Result<(), Box> { + /// # use azure_data_cosmos::query::QueryScope; /// # let container_client: azure_data_cosmos::clients::ContainerClient = panic!("this is a non-running example"); /// #[derive(serde::Deserialize)] /// struct Customer { @@ -685,8 +686,9 @@ impl ContainerClient { /// } /// let items = container_client.query_items::( /// "SELECT * FROM c", - /// "some_partition_key", - /// None)?; + /// QueryScope::partition("some_partition_key"), + /// None, + /// ).await?; /// # } /// ``` /// @@ -694,7 +696,7 @@ impl ContainerClient { /// /// ```rust,no_run /// # async fn doc() -> Result<(), Box> { - /// use azure_data_cosmos::Query; + /// use azure_data_cosmos::{query::QueryScope, Query}; /// # let container_client: azure_data_cosmos::clients::ContainerClient = panic!("this is a non-running example"); /// #[derive(serde::Deserialize)] /// struct Customer { @@ -703,34 +705,49 @@ impl ContainerClient { /// } /// let query = Query::from("SELECT COUNT(*) FROM c WHERE c.customer_id = @customer_id") /// .with_parameter("@customer_id", 42)?; - /// let items = container_client.query_items::(query, "some_partition_key", None)?; + /// let items = container_client + /// .query_items::(query, QueryScope::partition("some_partition_key"), None).await?; /// # } /// ``` /// /// See [`PartitionKey`](crate::PartitionKey) for more information on how to specify a partition key, and [`Query`] for more information on how to specify a query. - pub fn query_items( + pub async fn query_items( &self, query: impl Into, - partition_key: impl Into, + scope: QueryScope, options: Option, ) -> azure_core::Result> { let options = options.unwrap_or_default(); - let partition_key: PartitionKey = partition_key.into(); let query = query.into(); - let driver_pk = partition_key.into_driver_partition_key(); let container_ref = self.container_ref.clone(); - let factory = - move || CosmosOperation::query_items(container_ref.clone(), driver_pk.clone()); - crate::query::executor::QueryExecutor::new( + // The first operation to execute in the query items flow. + // This holds the session token provided by the user, if any. + let mut initial_operation = + CosmosOperation::query_items(container_ref.clone(), scope.into()) + .with_body(serde_json::to_vec(&query)?); + if let Some(token) = options.session_token { + initial_operation = initial_operation.with_session_token(token); + } + if let Some(max_item_count) = options.max_item_count { + initial_operation = initial_operation.with_max_item_count(max_item_count); + } + let plan = self + .context + .driver + .plan_operation( + initial_operation, + &options.operation, + options.continuation_token.as_ref(), + ) + .await?; + Ok(FeedItemIterator::new( self.context.driver.clone(), - factory, - query, + Some(self.container_ref.clone()), + plan, options.operation, - options.session_token, - ) - .into_stream() + )) } /// Executes a transactional batch of operations. @@ -781,7 +798,7 @@ impl ContainerClient { ) -> azure_core::Result { let options = options.unwrap_or_default(); let body = serde_json::to_vec(batch.operations())?; - let driver_pk = batch.partition_key().clone().into_driver_partition_key(); + let driver_pk = batch.partition_key().clone(); let operation = CosmosOperation::batch(self.container_ref.clone(), driver_pk).with_body(body); @@ -790,7 +807,7 @@ impl ContainerClient { let driver_response = self .context .driver - .execute_operation(operation, options.operation) + .execute_point_operation(operation, options.operation) .await?; Ok(BatchResponse::new( @@ -840,10 +857,7 @@ impl ContainerClient { )); } - ranges - .iter() - .map(FeedRange::from_partition_key_range) - .collect() + ranges.iter().map(FeedRange::try_from).collect() } /// Returns the [`FeedRange`]s covering the given partition key. @@ -856,7 +870,7 @@ impl ContainerClient { options: Option, ) -> azure_core::Result> { let partition_key = partition_key.into(); - let driver_pk = partition_key.into_driver_partition_key(); + let driver_pk = partition_key; let options = options.unwrap_or_default(); let pk_def = self.container_ref.partition_key_definition(); let values = driver_pk.values(); @@ -925,15 +939,9 @@ impl ContainerClient { )); } - ranges - .iter() - .map(FeedRange::from_partition_key_range) - .collect() + ranges.iter().map(FeedRange::try_from).collect() } else { - ranges - .iter() - .map(FeedRange::from_partition_key_range) - .collect() + ranges.iter().map(FeedRange::try_from).collect() } } diff --git a/sdk/cosmos/azure_data_cosmos/src/clients/cosmos_client.rs b/sdk/cosmos/azure_data_cosmos/src/clients/cosmos_client.rs index ee0121c2f86..2b4d9f775ff 100644 --- a/sdk/cosmos/azure_data_cosmos/src/clients/cosmos_client.rs +++ b/sdk/cosmos/azure_data_cosmos/src/clients/cosmos_client.rs @@ -117,29 +117,37 @@ impl CosmosClient { /// # async fn doc() -> Result<(), Box> { /// # use azure_data_cosmos::CosmosClient; /// # let client: CosmosClient = panic!("this is a non-running example"); - /// let dbs = client.query_databases( - /// "SELECT * FROM dbs", - /// None)?; + /// let dbs = client + /// .query_databases("SELECT * FROM dbs", None) + /// .await?; /// # } /// ``` /// /// See [`Query`] for more information on how to specify a query. - pub fn query_databases( + pub async fn query_databases( &self, query: impl Into, - _options: Option, + #[allow(unused_variables, reason = "This parameter may be used in the future")] + options: Option, ) -> azure_core::Result> { + let query = query.into(); let account = self.context.driver.account().clone(); - let factory = move || CosmosOperation::query_databases(account.clone()); + let initial_operation = + CosmosOperation::query_databases(account).with_body(serde_json::to_vec(&query)?); + let operation_options = OperationOptions::default(); + + let plan = self + .context + .driver + .plan_operation(initial_operation, &operation_options, None) + .await?; - crate::query::executor::QueryExecutor::new( + Ok(FeedItemIterator::new( self.context.driver.clone(), - factory, - query.into(), - Default::default(), None, - ) - .into_stream() + plan, + operation_options, + )) } /// Creates a new database. @@ -173,7 +181,7 @@ impl CosmosClient { let driver_response = self .context .driver - .execute_operation(operation, operation_options) + .execute_point_operation(operation, operation_options) .await?; Ok(ResourceResponse::new( @@ -181,3 +189,20 @@ impl CosmosClient { )) } } + +#[cfg(test)] +mod tests { + use super::*; + + /// Compile-time assertion that `CosmosClient` async method futures are `Send`. + /// + /// This function is never called; it only needs to compile. + /// If any future is not `Send`, compilation will fail. + #[allow(dead_code, unreachable_code, unused_variables)] + fn _assert_futures_are_send() { + fn assert_send(_: T) {} + let client: &CosmosClient = todo!(); + assert_send(client.query_databases(Query::from("SELECT * FROM dbs"), todo!())); + assert_send(client.create_database(todo!(), todo!())); + } +} diff --git a/sdk/cosmos/azure_data_cosmos/src/clients/database_client.rs b/sdk/cosmos/azure_data_cosmos/src/clients/database_client.rs index 98a606a5adc..ca69e442aab 100644 --- a/sdk/cosmos/azure_data_cosmos/src/clients/database_client.rs +++ b/sdk/cosmos/azure_data_cosmos/src/clients/database_client.rs @@ -83,7 +83,7 @@ impl DatabaseClient { let driver_response = self .context .driver - .execute_operation(operation, OperationOptions::default()) + .execute_point_operation(operation, OperationOptions::default()) .await?; Ok(ResourceResponse::new( @@ -107,33 +107,37 @@ impl DatabaseClient { /// # async fn doc() -> Result<(), Box> { /// # use azure_data_cosmos::clients::DatabaseClient; /// # let db_client: DatabaseClient = panic!("this is a non-running example"); - /// let containers = db_client.query_containers( - /// "SELECT * FROM dbs", - /// None)?; + /// let containers = db_client + /// .query_containers("SELECT * FROM dbs", None) + /// .await?; /// # } /// ``` /// /// See [`Query`] for more information on how to specify a query. #[allow(unused_variables, reason = "This parameter may be used in the future")] - pub fn query_containers( + pub async fn query_containers( &self, query: impl Into, + #[allow(unused_variables, reason = "This parameter may be used in the future")] options: Option, ) -> azure_core::Result> { - let db_ref = DatabaseReference::from_name( - self.context.driver.account().clone(), - self.database_id.clone(), - ); - let factory = move || CosmosOperation::query_containers(db_ref.clone()); + let query = query.into(); + let initial_operation = CosmosOperation::query_containers(self.database_ref.clone()) + .with_body(serde_json::to_vec(&query)?); + let operation_options = OperationOptions::default(); + + let plan = self + .context + .driver + .plan_operation(initial_operation, &operation_options, None) + .await?; - crate::query::executor::QueryExecutor::new( + Ok(FeedItemIterator::new( self.context.driver.clone(), - factory, - query.into(), - Default::default(), None, - ) - .into_stream() + plan, + operation_options, + )) } /// Creates a new container. @@ -168,7 +172,7 @@ impl DatabaseClient { let driver_response = self .context .driver - .execute_operation(operation, operation_options) + .execute_point_operation(operation, operation_options) .await?; Ok(ResourceResponse::new( @@ -192,7 +196,7 @@ impl DatabaseClient { let driver_response = self .context .driver - .execute_operation(operation, OperationOptions::default()) + .execute_point_operation(operation, OperationOptions::default()) .await?; Ok(ResourceResponse::new( @@ -275,3 +279,25 @@ impl DatabaseClient { .await } } + +#[cfg(test)] +mod tests { + use super::*; + + /// Compile-time assertion that `DatabaseClient` async method futures are `Send`. + /// + /// This function is never called; it only needs to compile. + /// If any future is not `Send`, compilation will fail. + #[allow(dead_code, unreachable_code, unused_variables)] + fn _assert_futures_are_send() { + fn assert_send(_: T) {} + let client: &DatabaseClient = todo!(); + assert_send(client.container_client(todo!())); + assert_send(client.read(todo!())); + assert_send(client.query_containers(Query::from("SELECT * FROM c"), todo!())); + assert_send(client.create_container(todo!(), todo!())); + assert_send(client.delete(todo!())); + assert_send(client.read_throughput(todo!())); + assert_send(client.begin_replace_throughput(todo!(), todo!())); + } +} diff --git a/sdk/cosmos/azure_data_cosmos/src/clients/offers_client.rs b/sdk/cosmos/azure_data_cosmos/src/clients/offers_client.rs index 5e88ab4a173..de3442f03e5 100644 --- a/sdk/cosmos/azure_data_cosmos/src/clients/offers_client.rs +++ b/sdk/cosmos/azure_data_cosmos/src/clients/offers_client.rs @@ -38,7 +38,7 @@ pub(crate) async fn find_offer( headers.insert(CONTENT_TYPE, HeaderValue::from("application/query+json")); let options = OperationOptions::default().with_custom_headers(headers); - let driver_response = driver.execute_operation(operation, options).await?; + let driver_response = driver.execute_point_operation(operation, options).await?; tracing::debug!( activity_id = ?driver_response.headers().activity_id, request_charge = ?driver_response.headers().request_charge, @@ -56,7 +56,7 @@ pub(crate) async fn read_offer_by_id( ) -> azure_core::Result> { let operation = CosmosOperation::read_offer(account.clone(), offer_id.to_owned()); let driver_response = driver - .execute_operation(operation, OperationOptions::default()) + .execute_point_operation(operation, OperationOptions::default()) .await?; Ok(crate::driver_bridge::driver_response_to_cosmos_response( driver_response, @@ -105,7 +105,9 @@ pub(crate) async fn begin_replace( opts }; - let driver_response = driver.execute_operation(operation, replace_options).await?; + let driver_response = driver + .execute_point_operation(operation, replace_options) + .await?; let response = crate::driver_bridge::driver_response_to_cosmos_response(driver_response); diff --git a/sdk/cosmos/azure_data_cosmos/src/driver_bridge.rs b/sdk/cosmos/azure_data_cosmos/src/driver_bridge.rs index 2bb19307c89..584df8c3fdd 100644 --- a/sdk/cosmos/azure_data_cosmos/src/driver_bridge.rs +++ b/sdk/cosmos/azure_data_cosmos/src/driver_bridge.rs @@ -338,12 +338,7 @@ mod tests { Some("totalExecutionTimeInMs=1.23;queryCompileTimeInMs=0.01"), ); - let rt = tokio::runtime::Runtime::new().unwrap(); - let page = rt - .block_on(QueryFeedPage::::from_response( - cosmos_response, - )) - .unwrap(); + let page = QueryFeedPage::::from_response(cosmos_response).unwrap(); assert_eq!( page.index_metrics(), Some(r#"{"UtilizedSingleIndexes":[]}"#) diff --git a/sdk/cosmos/azure_data_cosmos/src/feed.rs b/sdk/cosmos/azure_data_cosmos/src/feed.rs index 681e7a94c22..03cf70bfa71 100644 --- a/sdk/cosmos/azure_data_cosmos/src/feed.rs +++ b/sdk/cosmos/azure_data_cosmos/src/feed.rs @@ -1,21 +1,25 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -use std::{pin::Pin, sync::Arc, task}; +use std::{marker::PhantomData, pin::Pin, sync::Arc, task}; use azure_core::http::{ headers::Headers, pager::{PagerContinuation, PagerResult}, }; -use azure_data_cosmos_driver::models::CosmosResponseHeaders; -use futures::stream::BoxStream; +use azure_data_cosmos_driver::{ + models::{ContainerReference, CosmosResponse as DriverResponse, CosmosResponseHeaders}, + options::OperationOptions, + CosmosDriver, OperationPlan, +}; +use futures::future::BoxFuture; use futures::Stream; use serde::{de::DeserializeOwned, Deserialize}; use crate::{ - constants, + constants, driver_bridge, models::{CosmosDiagnosticsContext, CosmosResponse}, - SessionToken, + ContinuationToken, SessionToken, }; /// Represents a single page of results from a Cosmos DB feed. @@ -217,9 +221,7 @@ pub(crate) struct FeedBody { } impl QueryFeedPage { - pub(crate) async fn from_response( - response: CosmosResponse>, - ) -> azure_core::Result { + pub(crate) fn from_response(response: CosmosResponse>) -> azure_core::Result { let raw_headers = response.headers().clone(); let continuation = raw_headers.get_optional_string(&constants::CONTINUATION); let cosmos_headers = response.cosmos_headers().clone(); @@ -242,75 +244,257 @@ impl QueryFeedPage { } } +type DriverPageFuture = + BoxFuture<'static, (OperationPlan, azure_core::Result>)>; + +/// Live pipeline state held by [`FeedPageIterator`] / [`FeedItemIterator`]. +/// +/// Owns the [`OperationPlan`] directly (rather than burying it inside an +/// `unfold` closure) so that +/// [`FeedPageIterator::to_continuation_token`] can snapshot it between polls. +struct LiveState { + driver: Arc, + container: Option, + options: OperationOptions, + /// Always `Some` while no page fetch is in flight. + plan: Option, + /// `Some` while a page fetch is pending. + in_flight: Option, + exhausted: bool, +} + +impl LiveState { + fn new( + driver: Arc, + container: Option, + plan: OperationPlan, + options: OperationOptions, + ) -> Self { + Self { + driver, + container, + options, + plan: Some(plan), + in_flight: None, + exhausted: false, + } + } + + fn poll_next_page( + &mut self, + cx: &mut task::Context<'_>, + ) -> task::Poll>>> { + if self.exhausted { + return task::Poll::Ready(None); + } + + if self.in_flight.is_none() { + // Move the plan into a future. The future returns the plan back so + // we can store it again between polls. + let mut plan = self + .plan + .take() + .expect("plan must be present between polls"); + let driver = Arc::clone(&self.driver); + let container = self.container.clone(); + let options = self.options.clone(); + let fut: DriverPageFuture = Box::pin(async move { + let result = driver.execute_plan(&mut plan, container, options).await; + (plan, result) + }); + self.in_flight = Some(fut); + } + + let fut = self.in_flight.as_mut().expect("future just installed"); + let (plan, result) = match fut.as_mut().poll(cx) { + task::Poll::Pending => return task::Poll::Pending, + task::Poll::Ready(out) => out, + }; + self.in_flight = None; + self.plan = Some(plan); + + match result { + Ok(None) => { + self.exhausted = true; + task::Poll::Ready(None) + } + Err(err) => { + self.exhausted = true; + task::Poll::Ready(Some(Err(err))) + } + Ok(Some(driver_response)) => { + let response = driver_bridge::driver_response_to_cosmos_response::>( + driver_response, + ); + match QueryFeedPage::from_response(response) { + Ok(page) => task::Poll::Ready(Some(Ok(page))), + Err(err) => { + self.exhausted = true; + task::Poll::Ready(Some(Err(err))) + } + } + } + } + } + + fn to_continuation_token(&self) -> azure_core::Result { + let plan = self.plan.as_ref().ok_or_else(|| { + azure_core::Error::with_message( + azure_core::error::ErrorKind::Other, + "to_continuation_token called while a page fetch is in flight", + ) + })?; + plan.to_continuation_token() + } +} + +/// Internal source of pages for [`FeedPageIterator`] and [`FeedItemIterator`]. +/// +/// Production iterators use the [`Live`](Self::Live) variant which drives the +/// underlying [`OperationPlan`]. Unit tests use [`Synthetic`](Self::Synthetic) +/// to inject a pre-built sequence of pages. +enum PageSource { + Live(Box), + #[cfg(test)] + Synthetic(std::collections::VecDeque>>), + #[cfg(not(test))] + #[allow(dead_code)] + _Phantom(PhantomData T>), +} + +impl PageSource { + fn poll_next_page( + &mut self, + cx: &mut task::Context<'_>, + ) -> task::Poll>>> { + match self { + PageSource::Live(state) => state.poll_next_page::(cx), + #[cfg(test)] + PageSource::Synthetic(pages) => task::Poll::Ready(pages.pop_front()), + #[cfg(not(test))] + PageSource::_Phantom(_) => task::Poll::Ready(None), + } + } +} + /// Represents a stream of items from a Cosmos DB query. /// /// See [`QueryFeedPage`] for more details on Cosmos DB feeds. -#[pin_project::pin_project] pub struct FeedItemIterator { - #[pin] - pages: BoxStream<'static, azure_core::Result>>, + source: PageSource, current: Option>, + _marker: PhantomData T>, } -impl FeedItemIterator { - /// Creates a new `FeedItemIterator` from a stream of pages. +impl FeedItemIterator { + /// Creates a new `FeedItemIterator` backed by the given operation plan. pub(crate) fn new( - stream: impl Stream>> + Send + 'static, + driver: Arc, + container: Option, + plan: OperationPlan, + options: OperationOptions, ) -> Self { Self { - pages: Box::pin(stream), + source: PageSource::Live(Box::new(LiveState::new(driver, container, plan, options))), current: None, + _marker: PhantomData, } } + /// Converts this item iterator into a page iterator, yielding full pages + /// instead of individual items. + /// + /// IMPORTANT: This will DISCARD any items from the current page that have + /// not yet been yielded by the item iterator. Use this method before + /// consuming any items if you want to switch to page-based iteration. pub fn into_pages(self) -> FeedPageIterator { - FeedPageIterator(self.pages) + FeedPageIterator { + source: self.source, + _marker: PhantomData, + } } } -impl Stream for FeedItemIterator { +impl Stream for FeedItemIterator { type Item = azure_core::Result; fn poll_next( self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> task::Poll> { - let mut this = self.project(); + // Safety: we never move the inner source/current out via Pin. + let this = unsafe { self.get_unchecked_mut() }; loop { if let Some(current) = this.current.as_mut() { if let Some(item) = current.next() { return task::Poll::Ready(Some(Ok(item))); } - - // Reset the iterator and poll for the next page. - *this.current = None; + this.current = None; } - match this.pages.as_mut().poll_next(cx) { - task::Poll::Ready(page) => match page { - Some(Ok(page)) => { - *this.current = Some(page.page.items.into_iter()); - continue; - } - Some(Err(err)) => return task::Poll::Ready(Some(Err(err))), - None => return task::Poll::Ready(None), - }, + match this.source.poll_next_page(cx) { + task::Poll::Ready(Some(Ok(page))) => { + this.current = Some(page.into_items().into_iter()); + continue; + } + task::Poll::Ready(Some(Err(err))) => return task::Poll::Ready(Some(Err(err))), + task::Poll::Ready(None) => return task::Poll::Ready(None), task::Poll::Pending => return task::Poll::Pending, } } } } -pub struct FeedPageIterator(BoxStream<'static, azure_core::Result>>); +/// A stream of pages from a Cosmos DB feed operation. +/// +/// In addition to yielding [`QueryFeedPage`]s like a regular `Stream`, this +/// iterator can be snapshotted into a [`ContinuationToken`] for later +/// resumption via +/// [`to_continuation_token`](Self::to_continuation_token). +pub struct FeedPageIterator { + source: PageSource, + _marker: PhantomData T>, +} + +impl FeedPageIterator { + /// Captures the current iterator position as a [`ContinuationToken`]. + /// + /// Pass the returned token to a subsequent + /// [`ContainerClient::query_items`](crate::clients::ContainerClient::query_items) + /// call (via [`QueryOptions::with_continuation_token`](crate::QueryOptions::with_continuation_token)) + /// to resume the query at the same position. + /// + /// Snapshotting is non-mutating; the iterator may continue to be used + /// afterwards. + /// + /// # Errors + /// + /// Returns an error if a page fetch is currently in flight (the plan + /// state is being mutated and cannot be safely snapshotted). + pub fn to_continuation_token(&self) -> azure_core::Result { + match &self.source { + PageSource::Live(state) => state.to_continuation_token(), + #[cfg(test)] + PageSource::Synthetic(_) => Err(azure_core::Error::with_message( + azure_core::error::ErrorKind::Other, + "synthetic test iterator does not support to_continuation_token", + )), + #[cfg(not(test))] + PageSource::_Phantom(_) => unreachable!(), + } + } +} -impl Stream for FeedPageIterator { +impl Stream for FeedPageIterator { type Item = azure_core::Result>; fn poll_next( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> task::Poll> { - self.0.as_mut().poll_next(cx) + // Safety: we never move source out via Pin. + let this = unsafe { self.get_unchecked_mut() }; + this.source.poll_next_page(cx) } } @@ -334,6 +518,16 @@ mod tests { } } + fn synthetic_item_iter( + pages: Vec>>, + ) -> FeedItemIterator { + FeedItemIterator { + source: PageSource::Synthetic(pages.into()), + current: None, + _marker: PhantomData, + } + } + #[tokio::test] async fn item_iterator_yields_all_items_from_multiple_pages() { let pages = vec![ @@ -342,9 +536,7 @@ mod tests { Ok(create_test_page(vec![6], None)), ]; - let stream = futures::stream::iter(pages); - let item_iter = FeedItemIterator::new(stream); - + let item_iter = synthetic_item_iter(pages); let items: Vec<_> = item_iter .collect::>() .await @@ -361,9 +553,7 @@ mod tests { Ok(create_test_page(vec![3], None)), ]; - let stream = futures::stream::iter(pages); - let page_iter = FeedItemIterator::new(stream).into_pages(); - + let page_iter = synthetic_item_iter(pages).into_pages(); let page_items: Vec<_> = page_iter .collect::>() .await @@ -383,8 +573,7 @@ mod tests { )), ]; - let stream = futures::stream::iter(pages); - let mut item_iter = FeedItemIterator::new(stream); + let mut item_iter = synthetic_item_iter(pages); // First two items should succeed assert_eq!(item_iter.next().await.unwrap().unwrap(), 1); @@ -402,9 +591,7 @@ mod tests { Ok(create_test_page(vec![2], None)), ]; - let stream = futures::stream::iter(pages); - let item_iter = FeedItemIterator::new(stream); - + let item_iter = synthetic_item_iter(pages); let items: Vec<_> = item_iter .collect::>() .await diff --git a/sdk/cosmos/azure_data_cosmos/src/feed_range.rs b/sdk/cosmos/azure_data_cosmos/src/feed_range.rs deleted file mode 100644 index 831c379e1f3..00000000000 --- a/sdk/cosmos/azure_data_cosmos/src/feed_range.rs +++ /dev/null @@ -1,445 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -//! Types for working with feed ranges in Azure Cosmos DB. -//! -//! A [`FeedRange`] represents a contiguous range of partitions in a Cosmos DB container, -//! defined by effective partition key (EPK) boundaries. Feed ranges enable: -//! -//! - Parallel query processing by distributing ranges across workers -//! - Scoped change feed consumption for specific partitions -//! - Workload distribution across multiple consumers -//! -//! # Examples -//! -//! ```rust,no_run -//! # use azure_data_cosmos::clients::ContainerClient; -//! # async fn example(container: ContainerClient) -> azure_core::Result<()> { -//! // Get physical partition feed ranges -//! let ranges = container.read_feed_ranges(None).await?; -//! println!("Container has {} physical partitions", ranges.len()); -//! -//! // Serialize/deserialize for storage or transfer -//! let serialized = ranges[0].to_string(); -//! let restored: azure_data_cosmos::FeedRange = serialized.parse()?; -//! assert_eq!(ranges[0], restored); -//! # Ok(()) -//! # } -//! ``` - -use azure_core::fmt::SafeDebug; -use base64::Engine; -use serde::{Deserialize, Serialize}; -use std::fmt; -use std::str::FromStr; - -use azure_data_cosmos_driver::models::partition_key_range::PartitionKeyRange; - -use crate::hash::EffectivePartitionKey; -use crate::hash::{MAX_EXCLUSIVE_EFFECTIVE_PARTITION_KEY, MIN_INCLUSIVE_EFFECTIVE_PARTITION_KEY}; - -/// An opaque representation of a contiguous range of partitions in a Cosmos DB container. -/// -/// Feed ranges are defined by effective partition key (EPK) boundaries and map to one or more -/// physical partitions. They are obtained from [`ContainerClient::read_feed_ranges()`](crate::clients::ContainerClient::read_feed_ranges) -/// or [`ContainerClient::feed_range_from_partition_key()`](crate::clients::ContainerClient::feed_range_from_partition_key). -/// -/// Feed ranges can be serialized to strings (via [`std::fmt::Display`]/[`std::str::FromStr`]) for storage or transfer -/// between processes. The serialization format is base64-encoded JSON, compatible with other -/// Azure Cosmos DB SDKs. -/// -/// # Serialization Formats -/// -/// `FeedRange` supports two distinct serialization formats: -/// -/// - **[`Display`](std::fmt::Display)/[`FromStr`]** — base64-encoded JSON, intended for string storage and cross-SDK transfer. -/// - **[`Serialize`]/[`Deserialize`]** — structured JSON (`{"Range": {...}}`), intended for embedding in JSON documents. -/// -/// These formats are **not interchangeable**: a value serialized with one cannot be deserialized with the other. -#[derive(Clone, SafeDebug, PartialEq, Eq, Hash)] -#[non_exhaustive] -pub struct FeedRange { - pub(crate) min_inclusive: EffectivePartitionKey, - pub(crate) max_exclusive: EffectivePartitionKey, -} - -/// JSON wire format matching the cross-SDK feed range representation. -/// -/// Example: -/// ```json -/// {"Range": {"min": "", "max": "FF", "isMinInclusive": true, "isMaxInclusive": false}} -/// ``` -#[derive(Serialize, Deserialize)] -struct FeedRangeJson { - #[serde(rename = "Range")] - range: RangeJson, -} - -#[derive(Serialize, Deserialize)] -struct RangeJson { - min: String, - max: String, - #[serde(rename = "isMinInclusive")] - is_min_inclusive: bool, - #[serde(rename = "isMaxInclusive")] - is_max_inclusive: bool, -} - -impl FeedRange { - /// Creates a feed range covering the entire partition key space. - /// - /// This range spans from the minimum to maximum effective partition key values, - /// encompassing all partitions in a container. - pub fn full() -> Self { - Self { - min_inclusive: EffectivePartitionKey::from(MIN_INCLUSIVE_EFFECTIVE_PARTITION_KEY), - max_exclusive: EffectivePartitionKey::from(MAX_EXCLUSIVE_EFFECTIVE_PARTITION_KEY), - } - } - - /// Returns `true` if this feed range is entirely contained within `other`. - pub(crate) fn is_subset_of(&self, other: &FeedRange) -> bool { - other.min_inclusive <= self.min_inclusive && other.max_exclusive >= self.max_exclusive - } - - /// Returns `true` if this feed range and `other` share any portion of the EPK space. - /// - /// Two feed ranges overlap when one starts before the other ends and vice versa. - pub(crate) fn overlaps(&self, other: &FeedRange) -> bool { - self.min_inclusive < other.max_exclusive && other.min_inclusive < self.max_exclusive - } - - /// Returns `true` if this feed range can be combined with `other`. - /// - /// Two ranges can be combined when they overlap or are adjacent - /// (one's max equals the other's min). - pub(crate) fn can_merge(&self, other: &FeedRange) -> bool { - self.max_exclusive >= other.min_inclusive && other.max_exclusive >= self.min_inclusive - } - - /// Combines this feed range with `other` into a bounding range. - pub(crate) fn merge_with(&self, other: &FeedRange) -> FeedRange { - debug_assert!( - self.can_merge(other), - "merge_with called on disjoint ranges" - ); - FeedRange { - min_inclusive: std::cmp::min(self.min_inclusive.clone(), other.min_inclusive.clone()), - max_exclusive: std::cmp::max(self.max_exclusive.clone(), other.max_exclusive.clone()), - } - } - - /// Creates a `FeedRange` from a driver `PartitionKeyRange`. - /// - /// Partition key ranges from the service always use `[min, max)` semantics - /// (min inclusive, max exclusive). Returns an error if the range is inverted. - pub(crate) fn from_partition_key_range(pkr: &PartitionKeyRange) -> azure_core::Result { - if pkr.min_inclusive > pkr.max_exclusive { - return Err(azure_core::Error::with_message( - azure_core::error::ErrorKind::DataConversion, - "partition key range min_inclusive must be <= max_exclusive", - )); - } - Ok(Self { - min_inclusive: EffectivePartitionKey::from(pkr.min_inclusive.as_str()), - max_exclusive: EffectivePartitionKey::from(pkr.max_exclusive.as_str()), - }) - } - - /// Builds the JSON wire-format representation for serialization. - fn to_json(&self) -> FeedRangeJson { - FeedRangeJson { - range: RangeJson { - min: self.min_inclusive.as_str().to_owned(), - max: self.max_exclusive.as_str().to_owned(), - is_min_inclusive: true, - is_max_inclusive: false, - }, - } - } - - /// Validates and constructs a `FeedRange` from deserialized JSON fields. - /// - /// Checks inclusivity flags and min ≤ max ordering. - fn from_json(json: FeedRangeJson) -> azure_core::Result { - if !json.range.is_min_inclusive || json.range.is_max_inclusive { - return Err(azure_core::Error::with_message( - azure_core::error::ErrorKind::DataConversion, - "feed range must have [min, max) semantics (isMinInclusive=true, isMaxInclusive=false)", - )); - } - - let min = EffectivePartitionKey::from(json.range.min); - let max = EffectivePartitionKey::from(json.range.max); - - if min > max { - return Err(azure_core::Error::with_message( - azure_core::error::ErrorKind::DataConversion, - "feed range min must be less than or equal to max", - )); - } - - Ok(Self { - min_inclusive: min, - max_exclusive: max, - }) - } -} - -impl fmt::Display for FeedRange { - /// Formats this feed range as a base64-encoded JSON string. - /// - /// The output is compatible with other Azure Cosmos DB SDKs and can be - /// parsed back using [`std::str::FromStr`]. - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let json_str = serde_json::to_string(&self.to_json()).map_err(|_| fmt::Error)?; - let encoded = base64::engine::general_purpose::STANDARD.encode(json_str.as_bytes()); - f.write_str(&encoded) - } -} - -impl FromStr for FeedRange { - type Err = azure_core::Error; - - /// Parses a feed range from a base64-encoded JSON string. - /// - /// The input should be a string produced by [`std::fmt::Display`] or by another Azure Cosmos DB SDK. - fn from_str(s: &str) -> Result { - let decoded_bytes = base64::engine::general_purpose::STANDARD - .decode(s) - .map_err(|e| azure_core::Error::new(azure_core::error::ErrorKind::DataConversion, e))?; - - let json: FeedRangeJson = serde_json::from_slice(&decoded_bytes) - .map_err(|e| azure_core::Error::new(azure_core::error::ErrorKind::DataConversion, e))?; - - Self::from_json(json) - } -} - -impl Serialize for FeedRange { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - self.to_json().serialize(serializer) - } -} - -impl<'de> Deserialize<'de> for FeedRange { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - let json = FeedRangeJson::deserialize(deserializer)?; - Self::from_json(json).map_err(|e| serde::de::Error::custom(e.to_string())) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn full_range() { - let full = FeedRange::full(); - assert_eq!(full.min_inclusive.as_str(), ""); - assert_eq!(full.max_exclusive.as_str(), "FF"); - } - - #[test] - fn is_subset_of_full() { - let full = FeedRange::full(); - let sub = FeedRange { - min_inclusive: EffectivePartitionKey::from("00"), - max_exclusive: EffectivePartitionKey::from("80"), - }; - assert!(sub.is_subset_of(&full)); - assert!(!full.is_subset_of(&sub)); - } - - #[test] - fn is_subset_of_self() { - let range = FeedRange { - min_inclusive: EffectivePartitionKey::from("20"), - max_exclusive: EffectivePartitionKey::from("80"), - }; - assert!(range.is_subset_of(&range)); - } - - #[test] - fn overlaps_basic() { - let a = FeedRange { - min_inclusive: EffectivePartitionKey::from("00"), - max_exclusive: EffectivePartitionKey::from("50"), - }; - let b = FeedRange { - min_inclusive: EffectivePartitionKey::from("30"), - max_exclusive: EffectivePartitionKey::from("80"), - }; - assert!(a.overlaps(&b)); - assert!(b.overlaps(&a)); - } - - #[test] - fn overlaps_adjacent_no_overlap() { - let a = FeedRange { - min_inclusive: EffectivePartitionKey::from("00"), - max_exclusive: EffectivePartitionKey::from("50"), - }; - let b = FeedRange { - min_inclusive: EffectivePartitionKey::from("50"), - max_exclusive: EffectivePartitionKey::from("FF"), - }; - // Adjacent ranges (a's max == b's min) do NOT overlap because max is exclusive - assert!(!a.overlaps(&b)); - assert!(!b.overlaps(&a)); - } - - #[test] - fn overlaps_disjoint() { - let a = FeedRange { - min_inclusive: EffectivePartitionKey::from("00"), - max_exclusive: EffectivePartitionKey::from("30"), - }; - let b = FeedRange { - min_inclusive: EffectivePartitionKey::from("50"), - max_exclusive: EffectivePartitionKey::from("FF"), - }; - assert!(!a.overlaps(&b)); - assert!(!b.overlaps(&a)); - } - - #[test] - fn display_produces_expected_base64_full_range() { - let range = FeedRange { - min_inclusive: EffectivePartitionKey::from(""), - max_exclusive: EffectivePartitionKey::from("FF"), - }; - assert_eq!( - range.to_string(), - "eyJSYW5nZSI6eyJtaW4iOiIiLCJtYXgiOiJGRiIsImlzTWluSW5jbHVzaXZlIjp0cnVlLCJpc01heEluY2x1c2l2ZSI6ZmFsc2V9fQ==" - ); - } - - #[test] - fn display_produces_expected_base64_sub_range() { - let range = FeedRange { - min_inclusive: EffectivePartitionKey::from("3FFFFFFFFFFF"), - max_exclusive: EffectivePartitionKey::from("7FFFFFFFFFFF"), - }; - assert_eq!( - range.to_string(), - "eyJSYW5nZSI6eyJtaW4iOiIzRkZGRkZGRkZGRkYiLCJtYXgiOiI3RkZGRkZGRkZGRkYiLCJpc01pbkluY2x1c2l2ZSI6dHJ1ZSwiaXNNYXhJbmNsdXNpdmUiOmZhbHNlfX0=" - ); - } - - #[test] - fn from_str_parses_full_range() { - let input = "eyJSYW5nZSI6eyJtaW4iOiIiLCJtYXgiOiJGRiIsImlzTWluSW5jbHVzaXZlIjp0cnVlLCJpc01heEluY2x1c2l2ZSI6ZmFsc2V9fQ=="; - let range: FeedRange = input.parse().unwrap(); - assert_eq!(range.min_inclusive.as_str(), ""); - assert_eq!(range.max_exclusive.as_str(), "FF"); - } - - #[test] - fn from_str_parses_sub_range() { - let input = "eyJSYW5nZSI6eyJtaW4iOiIzRkZGRkZGRkZGRkYiLCJtYXgiOiI3RkZGRkZGRkZGRkYiLCJpc01pbkluY2x1c2l2ZSI6dHJ1ZSwiaXNNYXhJbmNsdXNpdmUiOmZhbHNlfX0="; - let range: FeedRange = input.parse().unwrap(); - assert_eq!(range.min_inclusive.as_str(), "3FFFFFFFFFFF"); - assert_eq!(range.max_exclusive.as_str(), "7FFFFFFFFFFF"); - } - - #[test] - fn serde_json_serializes_to_cross_sdk_format() { - let range = FeedRange { - min_inclusive: EffectivePartitionKey::from(""), - max_exclusive: EffectivePartitionKey::from("FF"), - }; - let json = serde_json::to_string(&range).unwrap(); - - let value: serde_json::Value = serde_json::from_str(&json).unwrap(); - let inner = value.get("Range").expect("expected 'Range' key"); - assert_eq!(inner.get("min").unwrap().as_str().unwrap(), ""); - assert_eq!(inner.get("max").unwrap().as_str().unwrap(), "FF"); - assert!(inner.get("isMinInclusive").unwrap().as_bool().unwrap()); - assert!(!inner.get("isMaxInclusive").unwrap().as_bool().unwrap()); - } - - #[test] - fn serde_json_deserializes_cross_sdk_format() { - let json = - r#"{"Range":{"min":"","max":"FF","isMinInclusive":true,"isMaxInclusive":false}}"#; - let range: FeedRange = serde_json::from_str(json).unwrap(); - assert_eq!(range.min_inclusive.as_str(), ""); - assert_eq!(range.max_exclusive.as_str(), "FF"); - } - - #[test] - fn from_str_invalid_base64() { - let result = "not-valid-base64!!!".parse::(); - assert!(result.is_err()); - } - - #[test] - fn from_str_invalid_json() { - let encoded = base64::engine::general_purpose::STANDARD.encode(b"not json"); - let result = encoded.parse::(); - assert!(result.is_err()); - } - - #[test] - fn from_partition_key_range() { - let pkr = PartitionKeyRange::new("0".to_string(), "".to_string(), "FF".to_string()); - let feed_range = FeedRange::from_partition_key_range(&pkr).unwrap(); - assert_eq!(feed_range.min_inclusive.as_str(), ""); - assert_eq!(feed_range.max_exclusive.as_str(), "FF"); - } - - #[test] - fn cross_sdk_compatibility() { - // Verify that the full range serializes to the same base64 string regardless of platform - let full = FeedRange::full(); - let serialized = full.to_string(); - - // Decode and verify the JSON structure - let decoded = base64::engine::general_purpose::STANDARD - .decode(&serialized) - .unwrap(); - let json: serde_json::Value = serde_json::from_slice(&decoded).unwrap(); - - let range = json.get("Range").unwrap(); - assert_eq!(range.get("min").unwrap().as_str().unwrap(), ""); - assert_eq!(range.get("max").unwrap().as_str().unwrap(), "FF"); - assert!(range.get("isMinInclusive").unwrap().as_bool().unwrap()); - assert!(!range.get("isMaxInclusive").unwrap().as_bool().unwrap()); - } - - #[test] - fn from_str_rejects_max_inclusive() { - let json = r#"{"Range":{"min":"","max":"FF","isMinInclusive":true,"isMaxInclusive":true}}"#; - let encoded = base64::engine::general_purpose::STANDARD.encode(json.as_bytes()); - assert!(encoded.parse::().is_err()); - } - - #[test] - fn serde_rejects_min_not_inclusive() { - let json = - r#"{"Range":{"min":"","max":"FF","isMinInclusive":false,"isMaxInclusive":false}}"#; - assert!(serde_json::from_str::(json).is_err()); - } - - #[test] - fn from_str_rejects_inverted_range() { - let json = - r#"{"Range":{"min":"FF","max":"","isMinInclusive":true,"isMaxInclusive":false}}"#; - let encoded = base64::engine::general_purpose::STANDARD.encode(json.as_bytes()); - assert!(encoded.parse::().is_err()); - } - - #[test] - fn serde_rejects_inverted_range() { - let json = - r#"{"Range":{"min":"FF","max":"","isMinInclusive":true,"isMaxInclusive":false}}"#; - assert!(serde_json::from_str::(json).is_err()); - } -} diff --git a/sdk/cosmos/azure_data_cosmos/src/hash.rs b/sdk/cosmos/azure_data_cosmos/src/hash.rs index 5a0bce076a6..bdaf991546e 100644 --- a/sdk/cosmos/azure_data_cosmos/src/hash.rs +++ b/sdk/cosmos/azure_data_cosmos/src/hash.rs @@ -2,345 +2,79 @@ // Licensed under the MIT License. use crate::models::PartitionKeyKind; -use crate::murmur_hash::{murmurhash3_128, murmurhash3_32}; -use std::fmt::Write; - -const MAX_STRING_BYTES_TO_APPEND: usize = 100; -pub(crate) const MIN_INCLUSIVE_EFFECTIVE_PARTITION_KEY: &str = ""; -pub(crate) const MAX_EXCLUSIVE_EFFECTIVE_PARTITION_KEY: &str = "FF"; +use azure_data_cosmos_driver::models::{ + effective_partition_key::EffectivePartitionKey as DriverEffectivePartitionKey, + PartitionKeyValue as DriverPartitionKeyValue, PartitionKeyVersion, +}; +use std::fmt; /// A strongly-typed wrapper around the hex-encoded effective partition key string. /// -/// Use [`AsRef`] to obtain the underlying string when passing to APIs -/// that accept `&str`. -/// -/// Ordering is lexicographic on the underlying hex string. This is correct because: -/// - All actual EPK hash values are uppercase hex strings of consistent length -/// - The sentinel MAX ("FF") sorts after all real hashes by the Cosmos DB EPK space design -/// - The sentinel MIN ("") sorts before everything +/// This SDK type wraps the driver's canonical EPK implementation while keeping +/// the SDK's public API surface explicit and stable. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct EffectivePartitionKey(String); +pub struct EffectivePartitionKey(DriverEffectivePartitionKey); impl EffectivePartitionKey { /// Returns the underlying string representation. pub fn as_str(&self) -> &str { - &self.0 + self.0.as_str() + } +} + +impl fmt::Display for EffectivePartitionKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) } } impl AsRef for EffectivePartitionKey { fn as_ref(&self) -> &str { - &self.0 + self.0.as_ref() } } impl From for EffectivePartitionKey { - fn from(s: String) -> Self { - Self(s) + fn from(value: String) -> Self { + Self(DriverEffectivePartitionKey::from(value)) } } impl From<&str> for EffectivePartitionKey { - fn from(s: &str) -> Self { - Self(s.to_owned()) - } -} - -/// Contains all allowed markers for component marker types. -mod component { - pub const UNDEFINED: u8 = 0x00; - pub const NULL: u8 = 0x01; - pub const BOOL_FALSE: u8 = 0x02; - pub const BOOL_TRUE: u8 = 0x03; - pub const NUMBER: u8 = 0x05; - pub const STRING: u8 = 0x08; - pub const INFINITY: u8 = 0xFF; -} - -#[derive(Clone, Debug, PartialEq)] -pub enum InnerPartitionKeyValue { - Null, - Bool(bool), - Number(f64), - String(String), - Infinity, - Undefined, -} - -// `f64` does not implement `Eq`, but in this domain partition key numbers are -// always finite, non-NaN values, so total equality holds. We implement `Eq` -// manually to express this invariant. -impl Eq for InnerPartitionKeyValue {} - -impl InnerPartitionKeyValue { - /// Common hashing writer core: writes type marker + payload (string suffix used by V2). - fn write_for_hashing_core(&self, string_suffix: u8, writer: &mut Vec, truncate: bool) { - match self { - InnerPartitionKeyValue::Bool(true) => writer.push(component::BOOL_TRUE), - InnerPartitionKeyValue::Bool(false) => writer.push(component::BOOL_FALSE), - InnerPartitionKeyValue::Null => writer.push(component::NULL), - InnerPartitionKeyValue::Number(n) => { - writer.push(component::NUMBER); // Number marker - let bytes = n.to_le_bytes(); - writer.extend_from_slice(&bytes); - } - InnerPartitionKeyValue::String(s) => { - writer.push(component::STRING); // String marker - let bytes = s.as_bytes(); - if truncate && bytes.len() > MAX_STRING_BYTES_TO_APPEND { - writer.extend_from_slice(&bytes[..MAX_STRING_BYTES_TO_APPEND]); - } else { - writer.extend_from_slice(bytes); - } - writer.push(string_suffix); - } - InnerPartitionKeyValue::Undefined => writer.push(component::UNDEFINED), - InnerPartitionKeyValue::Infinity => writer.push(component::INFINITY), - } - } - - /// V1 hashing wrapper (string suffix 0x00) - pub fn write_for_hashing_v1(&self, writer: &mut Vec) { - self.write_for_hashing_core(0x00u8, writer, true) - } - - /// V2 hashing wrapper (string suffix 0xFF) - pub fn write_for_hashing_v2(&self, writer: &mut Vec) { - self.write_for_hashing_core(0xFFu8, writer, false) - } - - /// V1 binary encoding (subset required for test cases): - /// * Bool -> marker (0x03 true / 0x02 false) - /// * Number -> marker (0x05) + variable-length 64-bit ordering-preserving encoding - /// * String -> marker (0x08) + each byte+1 (no 0xFF guard) up to 100 or 101 (if truncated) then 0x00 terminator if short - /// * Undefined -> marker (0x00) - /// * Null -> marker (0x01). - pub fn write_for_binary_encoding_v1(&self, writer: &mut Vec) { - match self { - InnerPartitionKeyValue::Bool(true) => writer.push(component::BOOL_TRUE), - InnerPartitionKeyValue::Bool(false) => writer.push(component::BOOL_FALSE), - InnerPartitionKeyValue::Infinity => writer.push(component::INFINITY), - InnerPartitionKeyValue::Number(n) => { - writer.push(component::NUMBER); - let mut payload = encode_double_as_uint64(*n); - // First 8 bits - writer.push((payload >> 56) as u8); - payload <<= 8; - let mut first = true; - let mut byte_to_write: u8 = 0; - while payload != 0 { - if !first { - writer.push(byte_to_write); - } else { - first = false; - } - byte_to_write = ((payload >> 56) as u8) | 0x01; // set continuation bit - payload <<= 7; // consume 7 bits (since we used 7 data bits + 1 flag) - } - writer.push(byte_to_write & 0xFE); // last byte with 0 flag - } - InnerPartitionKeyValue::String(s) => { - writer.push(component::STRING); - let utf8 = s.as_bytes(); - let short = utf8.len() <= MAX_STRING_BYTES_TO_APPEND; - // Use std::cmp to determine truncated write length (include sentinel +1 when longer than max) - let write_len = if short { - utf8.len() - } else { - std::cmp::min(utf8.len(), MAX_STRING_BYTES_TO_APPEND + 1) - }; - for item in utf8.iter().take(write_len) { - let b = item.wrapping_add(1); // unconditional +1 - writer.push(b); - } - if short { - writer.push(0x00); - } - } - InnerPartitionKeyValue::Undefined => writer.push(component::UNDEFINED), - InnerPartitionKeyValue::Null => writer.push(component::NULL), - } - } - - /// Binary encoding used by `_to_hex_encoded_binary_string`. - pub fn write_for_binary_encoding(&self, writer: &mut Vec) { - match self { - InnerPartitionKeyValue::Bool(true) => writer.push(component::BOOL_TRUE), - InnerPartitionKeyValue::Bool(false) => writer.push(component::BOOL_FALSE), - InnerPartitionKeyValue::Infinity => writer.push(component::INFINITY), - InnerPartitionKeyValue::Number(n) => { - writer.push(component::NUMBER); - // use IEEE754 little-endian double representation - writer.extend_from_slice(&n.to_le_bytes()); - } - InnerPartitionKeyValue::String(s) => { - writer.push(component::STRING); - let utf8 = s.as_bytes(); - let size = std::cmp::min(utf8.len(), MAX_STRING_BYTES_TO_APPEND); - let short_string: bool; - let write_len = if size == MAX_STRING_BYTES_TO_APPEND { - short_string = false; - size + 1 - } else { - short_string = true; - size - }; - for item in utf8.iter().take(write_len) { - let mut b = *item; - if b < 0xFF { - b = b.wrapping_add(1); - } - writer.push(b); - } - if short_string { - writer.push(0x00); - } - } - InnerPartitionKeyValue::Undefined => writer.push(component::UNDEFINED), - InnerPartitionKeyValue::Null => writer.push(component::NULL), - } + fn from(value: &str) -> Self { + Self(DriverEffectivePartitionKey::from(value)) } } /// Returns an [`EffectivePartitionKey`] representing the hashed partition key. +/// +/// Versions 1 and 2 map directly to the driver's partition key version enum. +/// Any other version falls back to V2 for forward-compatible behavior. +#[allow(dead_code)] // Currently exercised only by tests; kept for upcoming SDK API. pub fn get_hashed_partition_key_string( - pk_value: &[&InnerPartitionKeyValue], + pk_value: &[DriverPartitionKeyValue], kind: PartitionKeyKind, version: u8, ) -> EffectivePartitionKey { if pk_value.is_empty() { - return EffectivePartitionKey(MIN_INCLUSIVE_EFFECTIVE_PARTITION_KEY.to_string()); - } - if pk_value.len() == 1 && *pk_value[0] == InnerPartitionKeyValue::Infinity { - return EffectivePartitionKey(MAX_EXCLUSIVE_EFFECTIVE_PARTITION_KEY.to_string()); + return EffectivePartitionKey(DriverEffectivePartitionKey::min()); } - let raw = match kind { - PartitionKeyKind::Hash => match version { - 1 => get_effective_partition_key_for_hash_partitioning_v1(pk_value), - 2 => get_effective_partition_key_for_hash_partitioning_v2(pk_value), - _ => { - tracing::warn!( - "Hash partitioning version {} is not supported, falling back to binary encoding.", - version - ); - to_hex_encoded_binary_string(pk_value) - } - }, - PartitionKeyKind::MultiHash => { - // MultiHash is not yet implemented; use the non-hashed binary encoding - // as a deterministic fallback instead of panicking. - tracing::warn!( - "MultiHash partitioning is not yet supported, falling back to binary encoding." - ); - to_hex_encoded_binary_string(pk_value) - } - _ => { + let version = match version { + 1 => PartitionKeyVersion::V1, + 2 => PartitionKeyVersion::V2, + unsupported => { tracing::warn!( - "Unknown partition key kind '{:?}', falling back to binary encoding.", - kind + "Partition key hashing version {} is unsupported in SDK API; defaulting to V2", + unsupported ); - to_hex_encoded_binary_string(pk_value) + PartitionKeyVersion::V2 } }; - EffectivePartitionKey(raw) -} - -/// V2: encode components with `_write_for_hashing_v2`, hash the concatenated bytes, -fn get_effective_partition_key_for_hash_partitioning_v2( - pk_value: &[&InnerPartitionKeyValue], -) -> String { - let mut ms: Vec = Vec::new(); - for comp in pk_value { - comp.write_for_hashing_v2(&mut ms); - } - let hash_128 = murmurhash3_128(&ms, 0); - let mut hash_bytes = hash_128.to_le_bytes(); - hash_bytes.reverse(); - // Reset 2 most significant bits of first byte - hash_bytes[0] &= 0x3F; - bytes_to_hex_upper(&hash_bytes) -} - -/// V1: compute 32-bit murmur hash over concatenated component encodings (suffix 0x00 for strings), -/// convert hash (u32) to f64 (possible precision loss is intentional to mirror other sdks), then binary-encode -/// [hash_value_as_number] + truncated original components using V1 binary rules. -fn get_effective_partition_key_for_hash_partitioning_v1( - pk_value: &[&InnerPartitionKeyValue], -) -> String { - // Build hashing buffer using V1 hashing encoding (truncation is handled by write_for_hashing_v1) - let mut hashing_bytes: Vec = Vec::new(); - for v in pk_value { - v.write_for_hashing_v1(&mut hashing_bytes); - } - - let hash32 = murmurhash3_32(&hashing_bytes, 0u32); - let hash_value_f64 = hash32 as f64; // casts UInt32 -> float (lossy above 2^24) - - // For the binary encoding step, strings must also be truncated to match - // the truncation applied during hashing. - let hash_component = InnerPartitionKeyValue::Number(hash_value_f64); - let truncated_values: Vec = pk_value - .iter() - .map(|v| match v { - InnerPartitionKeyValue::String(s) if s.len() > MAX_STRING_BYTES_TO_APPEND => { - InnerPartitionKeyValue::String(s[..MAX_STRING_BYTES_TO_APPEND].to_string()) - } - other => (*other).clone(), - }) - .collect(); - - let mut components: Vec<&InnerPartitionKeyValue> = - Vec::with_capacity(truncated_values.len() + 1); - components.push(&hash_component); - components.extend(truncated_values.iter()); - - to_hex_encoded_binary_string_v1(&components) -} - -/// Encode multiple components into a binary buffer using V1 rules and return uppercase hex string. -fn to_hex_encoded_binary_string_v1(components: &[&InnerPartitionKeyValue]) -> String { - let mut buffer: Vec = Vec::new(); - for comp in components { - comp.write_for_binary_encoding_v1(&mut buffer); - } - bytes_to_hex_upper(&buffer) -} - -fn encode_double_as_uint64(value: f64) -> u64 { - let value_in_uint64 = u64::from_le_bytes(value.to_le_bytes()); - let mask: u64 = 0x8000_0000_0000_0000; - if value_in_uint64 < mask { - value_in_uint64 ^ mask - } else { - (!value_in_uint64).wrapping_add(1) - } -} - -/// Encode multiple components into a binary buffer and return lowercase hex string. -/// This corresponds to `_to_hex_encoded_binary_string` + `_write_for_binary_encoding`. -fn to_hex_encoded_binary_string(components: &[&InnerPartitionKeyValue]) -> String { - let mut buffer: Vec = Vec::new(); - for comp in components { - comp.write_for_binary_encoding(&mut buffer); - } - bytes_to_hex_lower(&buffer) -} - -fn bytes_to_hex_upper(bytes: &[u8]) -> String { - let mut s = String::with_capacity(bytes.len() * 2); - for b in bytes { - write!(&mut s, "{:02X}", b).unwrap(); - } - s -} -fn bytes_to_hex_lower(bytes: &[u8]) -> String { - let mut s = String::with_capacity(bytes.len() * 2); - for b in bytes { - write!(&mut s, "{:02x}", b).unwrap(); - } - s + EffectivePartitionKey(DriverEffectivePartitionKey::compute( + pk_value, kind, version, + )) } #[cfg(test)] @@ -348,172 +82,71 @@ mod tests { use super::*; #[test] - fn test_empty_pk() { + fn empty_pk_returns_min() { let result = get_hashed_partition_key_string(&[], PartitionKeyKind::Hash, 0); - assert_eq!(result.as_str(), MIN_INCLUSIVE_EFFECTIVE_PARTITION_KEY); + assert_eq!(result.as_str(), ""); } #[test] - fn test_infinity_pk() { - let inf = InnerPartitionKeyValue::Infinity; - let result = get_hashed_partition_key_string(&[&inf], PartitionKeyKind::Hash, 0); - assert_eq!(result.as_str(), MAX_EXCLUSIVE_EFFECTIVE_PARTITION_KEY); + fn single_string_hash_v2_matches_baseline() { + let comp = DriverPartitionKeyValue::from("customer42".to_string()); + let result = get_hashed_partition_key_string(&[comp], PartitionKeyKind::Hash, 2); + assert_eq!(result.as_str(), "19819C94CE42A1654CCC8110539D9589"); } #[test] - fn test_single_string_hash_v2() { - let comp = InnerPartitionKeyValue::String("customer42".to_string()); - let result = get_hashed_partition_key_string(&[&comp], PartitionKeyKind::Hash, 2); - // result should be a hex string of length 32 (16 bytes * 2 chars) - assert_eq!(result.as_str().len(), 32); - assert_eq!( - result.as_str(), - "19819C94CE42A1654CCC8110539D9589", - "Mismatch for component hash" - ) - } - - #[test] - fn test_effective_partition_key_hash_v2() { - // Each entry represents a single-component partition key and the expected - // effective partition key hash (uppercase hex) for V2 hash partitioning. - let thousand_a = "a".repeat(1024); - - // Expected values taken from Java SDK tests. - let cases: Vec<(InnerPartitionKeyValue, &str)> = vec![ + fn effective_partition_key_hash_v2_examples() { + let cases: Vec<(DriverPartitionKeyValue, &str)> = vec![ ( - InnerPartitionKeyValue::String(String::from("")), + DriverPartitionKeyValue::from(String::from("")), "32E9366E637A71B4E710384B2F4970A0", ), ( - InnerPartitionKeyValue::String(String::from("partitionKey")), + DriverPartitionKeyValue::from(String::from("partitionKey")), "013AEFCF77FA271571CF665A58C933F1", ), ( - InnerPartitionKeyValue::String(thousand_a), - "332BDF5512AE49615F32C7D98C2DB86C", - ), - ( - InnerPartitionKeyValue::Null, - "378867E4430E67857ACE5C908374FE16", - ), - ( - InnerPartitionKeyValue::Undefined, - "11622DAA78F835834610ABE56EFF5CB5", - ), - ( - InnerPartitionKeyValue::Bool(true), - "0E711127C5B5A8E4726AC6DD306A3E59", - ), - ( - InnerPartitionKeyValue::Bool(false), - "2FE1BE91E90A3439635E0E9E37361EF2", - ), - ( - InnerPartitionKeyValue::Number(-128f64), - "01DAEDABF913540367FE219B2AD06148", - ), // Java Byte.MIN_VALUE - ( - InnerPartitionKeyValue::Number(127f64), - "0C507ACAC853ECA7977BF4CEFB562A25", - ), // Java Byte.MAX_VALUE - ( - InnerPartitionKeyValue::Number(i64::MIN as f64), - "23D5C6395512BDFEAFADAD15328AD2BB", - ), - ( - InnerPartitionKeyValue::Number(i64::MAX as f64), - "2EDB959178DFCCA18983F89384D1629B", - ), - ( - InnerPartitionKeyValue::Number(i32::MIN as f64), - "0B1660D5233C3171725B30D4A5F4CC1F", - ), - ( - InnerPartitionKeyValue::Number(i32::MAX as f64), - "2D9349D64712AEB5EB1406E2F0BE2725", - ), - ( - InnerPartitionKeyValue::Number(f64::from_bits(0x1)), - "0E6CBA63A280927DE485DEF865800139", - ), // Java Double.MIN_VALUE - ( - InnerPartitionKeyValue::Number(f64::MAX), - "31424D996457102634591FF245DBCC4D", - ), - ( - InnerPartitionKeyValue::Number(5.0), + DriverPartitionKeyValue::from(5.0), "19C08621B135968252FB34B4CF66F811", ), ( - InnerPartitionKeyValue::Number(5.123_124_190_509_124), - "0EF2E2D82460884AF0F6440BE4F726A8", - ), - ( - InnerPartitionKeyValue::String(String::from("redmond")), + DriverPartitionKeyValue::from(String::from("redmond")), "22E342F38A486A088463DFF7838A5963", ), ]; for (component, expected) in &cases { - let actual = get_hashed_partition_key_string(&[component], PartitionKeyKind::Hash, 2); - assert_eq!(actual.as_str(), *expected, "Mismatch for component hash"); + let actual = get_hashed_partition_key_string( + std::slice::from_ref(component), + PartitionKeyKind::Hash, + 2, + ); + assert_eq!(actual.as_str(), *expected, "Mismatch for V2 component hash"); } } #[test] - fn test_effective_partition_key_hash_v2_multiple_keys() { - let component: Vec = vec![ - InnerPartitionKeyValue::Number(5.0), - InnerPartitionKeyValue::String(String::from("redmond")), - InnerPartitionKeyValue::Bool(true), - InnerPartitionKeyValue::Null, - ]; - let expected = "3032DECBE2AB1768D8E0AEDEA35881DF"; - - let refs: Vec<&InnerPartitionKeyValue> = component.iter().collect(); - let actual = get_hashed_partition_key_string(&refs, PartitionKeyKind::Hash, 2); - assert_eq!(actual.as_str(), expected, "Mismatch for component hash"); - } - - #[test] - fn test_effective_partition_key_hash_v1() { - // Expected strings are the direct V1 effective partition key representations (uppercase hex). - let thousand_a = "a".repeat(1024); - - // Expected values taken from Java SDK tests. - let cases: Vec<(InnerPartitionKeyValue, &str)> = vec![ - (InnerPartitionKeyValue::String(String::from("")), "05C1CF33970FF80800"), - (InnerPartitionKeyValue::String(String::from("partitionKey")), "05C1E1B3D9CD2608716273756A756A706F4C667A00"), - (InnerPartitionKeyValue::String(thousand_a), "05C1EB5921F706086262626262626262626262626262626262626262626262626262626262626262626262626262626262626262626262626262626262626262626262626262626262626262626262626262626262626262626262626262626262626262626262626262626200"), - (InnerPartitionKeyValue::Null, "05C1ED45D7475601"), - (InnerPartitionKeyValue::Undefined, "05C1D529E345DC00"), - (InnerPartitionKeyValue::Bool(true), "05C1D7C5A903D803"), - (InnerPartitionKeyValue::Bool(false), "05C1DB857D857C02"), - (InnerPartitionKeyValue::Number(-128f64), "05C1D73349F54C053FA0"), - (InnerPartitionKeyValue::Number(127f64), "05C1DD539DDFCC05C05FE0"), - (InnerPartitionKeyValue::Number(i64::MIN as f64), "05C1DB35F33D1C053C20"), - (InnerPartitionKeyValue::Number(i64::MAX as f64), "05C1B799AB2DD005C3E0"), - (InnerPartitionKeyValue::Number(i32::MIN as f64), "05C1DFBF252BCC053E20"), - (InnerPartitionKeyValue::Number(i32::MAX as f64), "05C1E1F503DFB205C1DFFFFFFFFC"), - (InnerPartitionKeyValue::Number(f64::from_bits(0x1)), "05C1E5C91F4D3005800101010101010102"), // Java Double.MIN_VALUE - (InnerPartitionKeyValue::Number(f64::MAX), "05C1CBE367C53005FFEFFFFFFFFFFFFFFE"), + fn effective_partition_key_hash_v1_examples() { + let cases: Vec<(DriverPartitionKeyValue, &str)> = vec![ + ( + DriverPartitionKeyValue::from(String::from("")), + "05C1CF33970FF80800", + ), + ( + DriverPartitionKeyValue::from(String::from("partitionKey")), + "05C1E1B3D9CD2608716273756A756A706F4C667A00", + ), + (DriverPartitionKeyValue::NULL, "05C1ED45D7475601"), + (DriverPartitionKeyValue::from(true), "05C1D7C5A903D803"), ]; for (component, expected) in &cases { - let actual = get_hashed_partition_key_string(&[component], PartitionKeyKind::Hash, 1); - assert_eq!( - actual.as_str(), - *expected, - "Mismatch for V1 component hash (enable test after implementation)" - ); - // unspecified version defaults to V1 - let actual = get_hashed_partition_key_string(&[component], PartitionKeyKind::Hash, 1); - assert_eq!( - actual.as_str(), - *expected, - "Mismatch for V1 component hash (enable test after implementation)" + let actual = get_hashed_partition_key_string( + std::slice::from_ref(component), + PartitionKeyKind::Hash, + 1, ); + assert_eq!(actual.as_str(), *expected, "Mismatch for V1 component hash"); } } } diff --git a/sdk/cosmos/azure_data_cosmos/src/lib.rs b/sdk/cosmos/azure_data_cosmos/src/lib.rs index 668b8dff49a..2b75cce0ce4 100644 --- a/sdk/cosmos/azure_data_cosmos/src/lib.rs +++ b/sdk/cosmos/azure_data_cosmos/src/lib.rs @@ -11,9 +11,7 @@ mod connection_string; pub mod constants; mod credential; mod feed; -mod feed_range; pub mod options; -mod partition_key; pub mod query; mod session_helpers; @@ -27,12 +25,20 @@ pub use clients::CosmosClientBuilder; pub use account_endpoint::CosmosAccountEndpoint; pub use account_reference::CosmosAccountReference; +#[doc(inline)] +pub use azure_data_cosmos_driver::models::ContinuationToken; +#[doc(inline)] +pub use azure_data_cosmos_driver::models::FeedRange; +#[doc(inline)] +pub use azure_data_cosmos_driver::models::PartitionKey; +#[doc(inline)] +pub use azure_data_cosmos_driver::models::PartitionKeyValue; pub use clients::ThroughputPoller; pub use connection_string::*; pub use credential::CosmosCredential; +pub use hash::EffectivePartitionKey; pub use models::{BatchResponse, CosmosDiagnosticsContext, ItemResponse, ResourceResponse}; pub use options::*; -pub use partition_key::*; pub use query::Query; pub use routing_strategy::RoutingStrategy; pub use transactional_batch::{ @@ -41,12 +47,10 @@ pub use transactional_batch::{ }; pub use feed::{FeedItemIterator, FeedPage, FeedPageIterator, QueryFeedPage}; -pub use feed_range::FeedRange; mod driver_bridge; #[cfg(feature = "fault_injection")] pub mod fault_injection; mod hash; -mod murmur_hash; mod region_proximity; pub mod regions; mod routing_strategy; diff --git a/sdk/cosmos/azure_data_cosmos/src/murmur_hash.rs b/sdk/cosmos/azure_data_cosmos/src/murmur_hash.rs deleted file mode 100644 index 7ed52d61d94..00000000000 --- a/sdk/cosmos/azure_data_cosmos/src/murmur_hash.rs +++ /dev/null @@ -1,283 +0,0 @@ -// The MIT License (MIT) -// Copyright (c) 2023 Microsoft Corporation -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all -// copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. -// -// Implementation of a variation of MurmurHash3 128-bit hash function. -// -// MurmurHash is a non-cryptographic hash function suitable for general hash-based lookup. The name comes from two basic -// operations, multiply (MU) and rotate (R), used in its inner loop. Unlike cryptographic hash functions, it is not -// specifically designed to be difficult to reverse by an adversary, making it unsuitable for cryptographic purposes. -// -// This contains a rust port of the 128-bit hash function from Austin Appleby's original C++ code in SMHasher. -// -// This is public domain code with no copyrights. From home page of -// SMHasher: -// "All MurmurHash versions are public domain software, and the author disclaims all copyright to their code." - -use std::convert::TryInto; - -// Helper accessors for 128-bit values encoded in native `u128`. -#[inline] -fn low64(v: u128) -> u64 { - v as u64 -} -#[inline] -fn high64(v: u128) -> u64 { - (v >> 64) as u64 -} - -/// Rotate left 64-bit. -#[inline] -pub fn rotate_left_64(val: u64, shift: u32) -> u64 { - val.rotate_left(shift) -} - -#[inline] -pub fn mix(mut value: u64) -> u64 { - value ^= value >> 33; - value = value.wrapping_mul(0xff51afd7ed558ccd); - value ^= value >> 33; - value = value.wrapping_mul(0xc4ceb9fe1a85ec53); - value ^= value >> 33; - value -} - -/// Rust SDK implementation of 128 bit murmurhash3 from Dot Net SDK. To match with other SDKs, It is recommended to -/// do the following with number values, especially floats as other SDKs use Doubles -/// -> bytearray(struct.pack("d", #)) where # represents any number. The d will treat it as a double. -/// MurmurHash3 x64 128-bit implementation -/// `span` is the input bytes, `seed` is a 128-bit value whose low/high 64-bit -/// lanes initialize the internal state. -pub fn murmurhash3_128(span: &[u8], seed: u128) -> u128 { - let c1: u64 = 0x87c37b91114253d5; - let c2: u64 = 0x4cf5ad432745937f; - - let mut h1: u64 = low64(seed); - let mut h2: u64 = high64(seed); - - let mut position = 0usize; - let len = span.len(); - - while position + 16 <= len { - let k1 = u64::from_le_bytes(span[position..position + 8].try_into().unwrap()); - let k2 = u64::from_le_bytes(span[position + 8..position + 16].try_into().unwrap()); - - // k1 - let mut k1 = k1.wrapping_mul(c1); - k1 = rotate_left_64(k1, 31); - k1 = k1.wrapping_mul(c2); - h1 ^= k1; - h1 = rotate_left_64(h1, 27); - h1 = h1.wrapping_add(h2); - h1 = h1.wrapping_mul(5).wrapping_add(0x52dce729); - - // k2 - let mut k2 = k2.wrapping_mul(c2); - k2 = rotate_left_64(k2, 33); - k2 = k2.wrapping_mul(c1); - h2 ^= k2; - h2 = rotate_left_64(h2, 31); - h2 = h2.wrapping_add(h1); - h2 = h2.wrapping_mul(5).wrapping_add(0x38495ab5); - - position += 16; - } - - // tail - let mut k1: u64 = 0; - let mut k2: u64 = 0; - let n = len & 15; - - if n >= 15 { - k2 ^= (span[position + 14] as u64) << 48; - } - if n >= 14 { - k2 ^= (span[position + 13] as u64) << 40; - } - if n >= 13 { - k2 ^= (span[position + 12] as u64) << 32; - } - if n >= 12 { - k2 ^= (span[position + 11] as u64) << 24; - } - if n >= 11 { - k2 ^= (span[position + 10] as u64) << 16; - } - if n >= 10 { - k2 ^= (span[position + 9] as u64) << 8; - } - if n >= 9 { - k2 ^= span[position + 8] as u64; - } - - if k2 != 0 { - k2 = k2.wrapping_mul(c2); - k2 = rotate_left_64(k2, 33); - k2 = k2.wrapping_mul(c1); - h2 ^= k2; - } - - if n >= 8 { - k1 ^= (span[position + 7] as u64) << 56; - } - if n >= 7 { - k1 ^= (span[position + 6] as u64) << 48; - } - if n >= 6 { - k1 ^= (span[position + 5] as u64) << 40; - } - if n >= 5 { - k1 ^= (span[position + 4] as u64) << 32; - } - if n >= 4 { - k1 ^= (span[position + 3] as u64) << 24; - } - if n >= 3 { - k1 ^= (span[position + 2] as u64) << 16; - } - if n >= 2 { - k1 ^= (span[position + 1] as u64) << 8; - } - if n >= 1 { - k1 ^= span[position] as u64; - k1 = k1.wrapping_mul(c1); - k1 = rotate_left_64(k1, 31); - k1 = k1.wrapping_mul(c2); - h1 ^= k1; - } - - // finalization - h1 ^= len as u64; - h2 ^= len as u64; - h1 = h1.wrapping_add(h2); - h2 = h2.wrapping_add(h1); - h1 = mix(h1); - h2 = mix(h2); - h1 = h1.wrapping_add(h2); - h2 = h2.wrapping_add(h1); - - ((h2 as u128) << 64) | (h1 as u128) -} - -/// MurmurHash3 32-bit implementation -pub fn murmurhash3_32(data: &[u8], seed: u32) -> u32 { - let c1: u32 = 0xcc9e2d51; - let c2: u32 = 0x1b873593; - let length: u32 = data.len() as u32; - let mut h1: u32 = seed; - let rounded_end = (length & 0xfffffffc) as usize; // round down to 4 byte block - - let mut i = 0usize; - while i < rounded_end { - let k1 = (data[i] as u32) - | ((data[i + 1] as u32) << 8) - | ((data[i + 2] as u32) << 16) - | ((data[i + 3] as u32) << 24); - i += 4; - - let mut k1 = k1.wrapping_mul(c1); - k1 = k1.rotate_left(15); - k1 = k1.wrapping_mul(c2); - - h1 ^= k1; - h1 = h1.rotate_left(13); - h1 = h1.wrapping_mul(5).wrapping_add(0xe6546b64); - } - - // tail - let mut k1_tail: u32 = 0; - let tail = (length & 0x03) as usize; - if tail == 3 { - k1_tail ^= (data[rounded_end + 2] as u32) << 16; - } - if tail >= 2 { - k1_tail ^= (data[rounded_end + 1] as u32) << 8; - } - if tail >= 1 { - k1_tail ^= data[rounded_end] as u32; - k1_tail = k1_tail.wrapping_mul(c1); - k1_tail = k1_tail.rotate_left(15); - k1_tail = k1_tail.wrapping_mul(c2); - h1 ^= k1_tail; - } - - // finalization - h1 ^= length; - h1 ^= h1 >> 16; - h1 = h1.wrapping_mul(0x85ebca6b); - h1 ^= h1 >> 13; - h1 = h1.wrapping_mul(0xc2b2ae35); - h1 ^= h1 >> 16; - - h1 -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_murmurhash3_128_float() { - let value: f64 = 374.0; - let ba_le: [u8; 8] = value.to_le_bytes(); - let ba_vec: Vec = ba_le.to_vec(); - let seed: u128 = 0; - let h = murmurhash3_128(&ba_vec, seed); - let _u128 = h; - // known results - assert_eq!(low64(h), 16628891264555680919); - assert_eq!(high64(h), 12953474369317462); - } - - #[test] - fn test_murmurhash3_128_string() { - let s = "sample-test"; - let bytes = s.as_bytes(); - let seed: u128 = 0; - let h = murmurhash3_128(bytes, seed); - let _u128 = h; - // known results - assert_eq!(low64(h), 9863137013172825203); - assert_eq!(high64(h), 15859947107521786564); - } - - #[test] - fn test_murmurhash3_32_float() { - let value: f64 = 374.0; - let ba_le: [u8; 8] = value.to_le_bytes(); - let ba_vec: Vec = ba_le.to_vec(); - let seed: u32 = 0; - let h = murmurhash3_32(&ba_vec, seed); - let _u32 = h; - // known results - assert_eq!(h, 3717946798); - } - - #[test] - fn test_murmurhash3_32_string() { - let s = "sample-test"; - let bytes = s.as_bytes(); - let seed: u32 = 0; - let h = murmurhash3_32(bytes, seed); - let _u32 = h; - // known results - assert_eq!(h, 2066086989); - } -} diff --git a/sdk/cosmos/azure_data_cosmos/src/options/mod.rs b/sdk/cosmos/azure_data_cosmos/src/options/mod.rs index 7ed78aae8c1..92d13971cd9 100644 --- a/sdk/cosmos/azure_data_cosmos/src/options/mod.rs +++ b/sdk/cosmos/azure_data_cosmos/src/options/mod.rs @@ -2,6 +2,7 @@ // Licensed under the MIT License. use crate::models::ThroughputProperties; +use crate::ContinuationToken; use std::fmt; use std::fmt::Display; @@ -251,6 +252,22 @@ pub struct QueryOptions { /// Session token for session-consistent queries. pub session_token: Option, + + /// Maximum number of items to return per page. + /// + /// When set, the server will return at most this many items in each response page. + /// If not set, the server uses its default page size. + /// + /// This is a _hint_ to the server, not a client-side guarantee of the maximum returned page size. + /// In a cross-partition query, each partition may return up to this many items, + /// so the total page size could be up to this value times the number of partitions involved. + /// Some server operations may return fewer, or even more, items than this value based on internal heuristics. + pub max_item_count: Option, + + /// Continuation token from a prior page iterator, used to resume the query. + /// + /// See [`FeedPageIterator::to_continuation_token`](crate::FeedPageIterator::to_continuation_token). + pub continuation_token: Option, } impl QueryOptions { @@ -265,6 +282,18 @@ impl QueryOptions { self.operation = operation; self } + + /// Sets the maximum number of items to return per page. + pub fn with_max_item_count(mut self, max_item_count: u32) -> Self { + self.max_item_count = Some(max_item_count); + self + } + + /// Sets a continuation token to resume the query at a previous position. + pub fn with_continuation_token(mut self, continuation_token: ContinuationToken) -> Self { + self.continuation_token = Some(continuation_token); + self + } } /// Options to be passed to [`ContainerClient::read()`](crate::clients::ContainerClient::read()). diff --git a/sdk/cosmos/azure_data_cosmos/src/partition_key.rs b/sdk/cosmos/azure_data_cosmos/src/partition_key.rs deleted file mode 100644 index e418f7c7be2..00000000000 --- a/sdk/cosmos/azure_data_cosmos/src/partition_key.rs +++ /dev/null @@ -1,617 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -use std::borrow::Cow; - -use azure_core::http::headers::{AsHeaders, HeaderName, HeaderValue}; - -use crate::constants; -use crate::hash::{get_hashed_partition_key_string, EffectivePartitionKey, InnerPartitionKeyValue}; -use crate::models::PartitionKeyKind; - -/// Specifies a partition key value, usually used when querying a specific partition. -/// -/// # Specifying a partition key -/// -/// Most APIs that require a partition key will accept `impl Into`, giving you a few options on how to specify your partition key. -/// -/// A single, non-hierarchical, partition key can be specified using the underlying type itself: -/// -/// ```rust,no_run -/// # use azure_data_cosmos::clients::ContainerClient; -/// # let container_client: ContainerClient = panic!("this is a non-running example"); -/// container_client.query_items::( -/// "SELECT * FROM c", -/// "a single string partition key", -/// None).unwrap(); -/// container_client.query_items::( -/// "SELECT * FROM c", -/// 42, // A numeric partition key -/// None).unwrap(); -/// ``` -/// -/// Hierarchical partition keys can be specified using tuples: -/// -/// ```rust,no_run -/// # use azure_data_cosmos::clients::ContainerClient; -/// # let container_client: ContainerClient = panic!("this is a non-running example"); -/// container_client.query_items::( -/// "SELECT * FROM c", -/// ("parent", "child"), -/// None).unwrap(); -/// ``` -/// -/// Null values can be represented in one of two ways. -/// First, you can use the value [`PartitionKey::NULL`]: -/// -/// ```rust,no_run -/// # use azure_data_cosmos::{clients::ContainerClient, PartitionKey}; -/// # let container_client: ContainerClient = panic!("this is a non-running example"); -/// container_client.query_items::( -/// "SELECT * FROM c", -/// PartitionKey::NULL, -/// None).unwrap(); -/// container_client.query_items::( -/// "SELECT * FROM c", -/// ("a", PartitionKey::NULL, "b"), // A null value within a hierarchical partition key. -/// None).unwrap(); -/// ``` -/// -/// Undefined partition key values can be represented using [`PartitionKey::UNDEFINED`]. -/// This is used to refer to items where the partition key property is absent from the document. -/// This is distinct from `null` (where the property exists but has a JSON null value). -/// -/// ```rust,no_run -/// # use azure_data_cosmos::{clients::ContainerClient, PartitionKey}; -/// # let container_client: ContainerClient = panic!("this is a non-running example"); -/// # async { -/// container_client.read_item::( -/// PartitionKey::UNDEFINED, -/// "item_without_partition_key_property", -/// None).await.unwrap(); -/// # }; -/// ``` -/// -/// Or, if you have an [`Option`], for some `T` that is valid as a partition key, it will automatically be serialized as `null` if it has the value [`Option::None`]: -/// -/// ```rust,no_run -/// # use azure_data_cosmos::clients::ContainerClient; -/// # let container_client: ContainerClient = panic!("this is a non-running example"); -/// let my_partition_key: Option = None; -/// container_client.query_items::( -/// "SELECT * FROM c", -/// my_partition_key, -/// None).unwrap(); -/// ``` -/// -/// If you want to create your [`PartitionKey`] and store it in a variable, use [`PartitionKey::from()`] -/// -/// ```rust -/// # use azure_data_cosmos::PartitionKey; -/// let partition_key_1 = PartitionKey::from("simple_string"); -/// let partition_key_2 = PartitionKey::from(("parent", "child", 42)); -/// ``` -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct PartitionKey(Vec); - -impl PartitionKey { - /// A single null partition key value, which can be used as the sole partition key or as part of a hierarchical partition key. - pub const NULL: PartitionKeyValue = PartitionKeyValue(InnerPartitionKeyValue::Null); - - /// A single undefined partition key value, used to target items where the partition key property is absent from the document. - /// - /// This is distinct from [`PartitionKey::NULL`], which targets items where the partition key property exists but has a JSON `null` value. - /// An undefined value is serialized as `{}` (an empty JSON object) in the partition key header. - /// For example, a single `UNDEFINED` value serializes to `[{}]`. - pub const UNDEFINED: PartitionKeyValue = PartitionKeyValue(InnerPartitionKeyValue::Undefined); - - /// An empty list of partition key values, which is used to signal a cross-partition query, when querying a container. - pub const EMPTY: PartitionKey = PartitionKey(Vec::new()); - - /// Converts this SDK partition key into the driver's equivalent type. - pub(crate) fn into_driver_partition_key( - self, - ) -> azure_data_cosmos_driver::models::PartitionKey { - use azure_data_cosmos_driver::models::{ - PartitionKey as DriverPK, PartitionKeyValue as DriverPKV, - }; - - let driver_values: Vec = self - .0 - .into_iter() - .map(|v| match v.0 { - InnerPartitionKeyValue::String(s) => DriverPKV::from(s), - InnerPartitionKeyValue::Number(n) => DriverPKV::from(n), - InnerPartitionKeyValue::Bool(b) => DriverPKV::from(b), - InnerPartitionKeyValue::Null => DriverPKV::from(Option::::None), - InnerPartitionKeyValue::Undefined => DriverPKV::undefined(), - InnerPartitionKeyValue::Infinity => { - // Infinity is an internal sentinel for EPK boundary calculations - // and cannot be constructed via the public SDK API. - // Mapping to Null as a defensive fallback; this path should be unreachable. - DriverPKV::from(Option::::None) - } - }) - .collect(); - - DriverPK::from(driver_values) - } - - /// Returns a hex string representation of the partition key hash. - /// - /// # Arguments - /// * `kind` - The partition key kind (Hash or MultiHash) - /// * `version` - The hash version (1 or 2) - /// - /// # Returns - /// An `EffectivePartitionKey` representing the hashed partition key - pub fn get_hashed_partition_key_string( - &self, - kind: PartitionKeyKind, - version: u8, - ) -> EffectivePartitionKey { - let inner_values: Vec<&InnerPartitionKeyValue> = self.0.iter().map(|v| &v.0).collect(); - get_hashed_partition_key_string(&inner_values, kind, version) - } -} - -impl AsHeaders for PartitionKey { - type Error = azure_core::Error; - type Iter = std::iter::Once<(HeaderName, HeaderValue)>; - - fn as_headers(&self) -> Result { - // We have to do some manual JSON serialization here. - // The partition key is sent in an HTTP header, when used to set the partition key for a query. - // It's not safe to use non-ASCII characters in HTTP headers, and serde_json will not escape non-ASCII characters if they are otherwise valid as UTF-8. - // So, we do some conversion by hand, with the help of Rust's own `encode_utf16` method which gives us the necessary code points for non-ASCII values, and produces surrogate pairs as needed. - - // Quick shortcut for empty partition keys list, which also prevents a bug when we pop the trailing comma for an empty list. - if self.0.is_empty() { - // An empty partition key means a cross partition query - return Ok(std::iter::once(( - constants::QUERY_ENABLE_CROSS_PARTITION, - HeaderValue::from_static("True"), - ))); - } - - let mut json = String::new(); - let mut utf_buf = [0; 2]; // A buffer for encoding UTF-16 characters. - json.push('['); - for key in &self.0 { - match key.0 { - InnerPartitionKeyValue::Undefined => json.push_str("{}"), - InnerPartitionKeyValue::Null => json.push_str("null"), - InnerPartitionKeyValue::Bool(b) => json.push_str(if b { "true" } else { "false" }), - InnerPartitionKeyValue::String(ref string_key) => { - json.push('"'); - for char in string_key.chars() { - match char { - '\x08' => json.push_str(r#"\b"#), - '\x0c' => json.push_str(r#"\f"#), - '\n' => json.push_str(r#"\n"#), - '\r' => json.push_str(r#"\r"#), - '\t' => json.push_str(r#"\t"#), - '"' => json.push_str(r#"\""#), - '\\' => json.push('\\'), - c if c.is_ascii() => json.push(c), - c => { - let encoded = c.encode_utf16(&mut utf_buf); - for code_unit in encoded { - json.push_str(&format!(r#"\u{:04x}"#, code_unit)); - } - } - } - } - json.push('"'); - } - InnerPartitionKeyValue::Number(ref num) => { - json.push_str(&num.to_string()); - } - InnerPartitionKeyValue::Infinity => json.push_str("\"Infinity\""), - } - - json.push(','); - } - - // Pop the trailing ',' (only if we actually wrote any values) - if json.ends_with(',') { - json.pop(); - } - json.push(']'); - - Ok(std::iter::once(( - constants::PARTITION_KEY, - HeaderValue::from_cow(json), - ))) - } -} - -/// Represents a value for a single partition key. -/// -/// You shouldn't need to construct this type directly. The various implementations of [`Into`] will handle it for you. -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct PartitionKeyValue(InnerPartitionKeyValue); - -impl From for PartitionKeyValue { - fn from(value: InnerPartitionKeyValue) -> Self { - PartitionKeyValue(value) - } -} - -impl From<&'static str> for PartitionKeyValue { - fn from(value: &'static str) -> Self { - InnerPartitionKeyValue::String(value.to_string()).into() - } -} - -impl From for PartitionKeyValue { - fn from(value: String) -> Self { - InnerPartitionKeyValue::String(value).into() - } -} - -impl From<&String> for PartitionKeyValue { - fn from(value: &String) -> Self { - InnerPartitionKeyValue::String(value.clone()).into() - } -} - -impl From> for PartitionKeyValue { - fn from(value: Cow<'static, str>) -> Self { - InnerPartitionKeyValue::String(value.into_owned()).into() - } -} - -macro_rules! impl_from_number { - ($source_type: ty) => { - impl From<$source_type> for PartitionKeyValue { - fn from(value: $source_type) -> Self { - InnerPartitionKeyValue::Number(value as f64).into() - } - } - }; -} - -impl_from_number!(i16); -impl_from_number!(i32); -impl_from_number!(i64); -impl_from_number!(i8); -impl_from_number!(isize); -impl_from_number!(u16); -impl_from_number!(u32); -impl_from_number!(u64); -impl_from_number!(u8); -impl_from_number!(usize); - -impl From for PartitionKeyValue { - /// Creates a [`PartitionKeyValue`] from an `f32`. - /// - /// WARNING: This extends the precision of the value from `f32` to `f64`. - /// - /// # Panics - /// - /// This method panics if given an Infinite or NaN value. - fn from(value: f32) -> Self { - assert!( - !value.is_infinite() && !value.is_nan(), - "value should be a non-infinite number" - ); - InnerPartitionKeyValue::Number(value as f64).into() - } -} - -impl From for PartitionKeyValue { - /// Creates a [`PartitionKeyValue`] from an `f64`. - /// - /// # Panics - /// - /// This method panics if given an Infinite or NaN value. - fn from(value: f64) -> Self { - assert!( - !value.is_infinite() && !value.is_nan(), - "value should be a non-infinite number" - ); - InnerPartitionKeyValue::Number(value).into() - } -} - -impl> From> for PartitionKeyValue { - fn from(value: Option) -> Self { - match value { - Some(t) => t.into(), - None => InnerPartitionKeyValue::Null.into(), - } - } -} - -impl From<()> for PartitionKey { - fn from(_: ()) -> Self { - PartitionKey::EMPTY - } -} - -impl From> for PartitionKey { - /// Creates a [`PartitionKey`] from a vector of [`PartitionKeyValue`]s. - /// - /// This is useful when the partition key structure is determined at runtime, - /// such as when working with multiple containers with different schemas or - /// building partition keys from configuration. - /// - /// # Panics - /// - /// Panics if the vector contains more than 3 elements, as Cosmos DB supports - /// a maximum of 3 hierarchical partition key levels. - /// - /// # Examples - /// - /// ```rust - /// use azure_data_cosmos::{PartitionKey, PartitionKeyValue}; - /// - /// // Single-level partition key - /// let keys = vec![PartitionKeyValue::from("tenant1")]; - /// let partition_key = PartitionKey::from(keys); - /// - /// // Multi-level partition key built at runtime - /// let mut keys = vec![PartitionKeyValue::from("tenant1")]; - /// keys.push(PartitionKeyValue::from("region1")); - /// let partition_key = PartitionKey::from(keys); - /// ``` - fn from(values: Vec) -> Self { - assert!( - values.len() <= 3, - "Partition keys can have at most 3 levels, got {}", - values.len() - ); - PartitionKey(values) - } -} - -impl> From for PartitionKey { - fn from(value: T) -> Self { - PartitionKey(vec![value.into()]) - } -} - -macro_rules! impl_from_tuple { - ($($n:tt $name:ident)*) => { - impl<$($name: Into),*> From<($($name,)*)> for PartitionKey { - fn from(value: ($($name,)*)) -> Self { - PartitionKey(vec![$( - value.$n.into() - ),*]) - } - } - }; -} - -// CosmosDB hierarchical partition keys are up to 3 levels: -// https://learn.microsoft.com/en-us/azure/cosmos-db/hierarchical-partition-keys -impl_from_tuple!(0 A 1 B); -impl_from_tuple!(0 A 1 B 2 C); - -#[cfg(test)] -mod tests { - use crate::{constants, PartitionKey, PartitionKeyValue}; - use azure_core::http::headers::AsHeaders; - - fn key_to_string(v: impl Into) -> String { - let key = v.into(); - let mut headers_iter = key.as_headers().unwrap(); - let (name, value) = headers_iter.next().unwrap(); - assert_eq!(constants::PARTITION_KEY, name); - value.as_str().into() - } - - /// Validates that a given value is `impl Into` and works as-expected. - fn key_to_single_string_partition_key(v: Option>) -> Option { - v.map(|k| key_to_string(k)) - } - - #[test] - pub fn static_str() { - assert_eq!(key_to_string("my_partition_key"), r#"["my_partition_key"]"#); - assert_eq!( - key_to_single_string_partition_key(Some("my_partition_key")).as_deref(), - Some(r#"["my_partition_key"]"#) - ); - } - - #[test] - pub fn integers() { - assert_eq!(key_to_string(42u8), r#"[42]"#); - assert_eq!(key_to_string(42u16), r#"[42]"#); - assert_eq!(key_to_string(42u32), r#"[42]"#); - assert_eq!(key_to_string(42u64), r#"[42]"#); - assert_eq!(key_to_string(42usize), r#"[42]"#); - assert_eq!(key_to_string(42i8), r#"[42]"#); - assert_eq!(key_to_string(42i16), r#"[42]"#); - assert_eq!(key_to_string(42i32), r#"[42]"#); - assert_eq!(key_to_string(42i64), r#"[42]"#); - assert_eq!(key_to_string(42isize), r#"[42]"#); - } - - #[test] - pub fn floats() { - // The f32 gets up-cast to f64, which results in a rounding issue. - // It's serde_json's default behavior, so we expect it, even if it isn't ideal. - assert_eq!(key_to_string(4.2f32), r#"[4.199999809265137]"#); - assert_eq!(key_to_string(4.2f64), r#"[4.2]"#); - } - - #[test] - pub fn options() { - let some: Option<&str> = Some("my_partition_key"); - let none: Option<&str> = None; - assert_eq!(key_to_string(some), r#"["my_partition_key"]"#); - assert_eq!(key_to_string(none), r#"[null]"#); - } - - #[test] - fn from_vec_empty() { - let keys: Vec = vec![]; - let partition_key = PartitionKey::from(keys); - assert_eq!(Vec::::new(), partition_key.0); - - let mut headers_iter = partition_key.as_headers().unwrap(); - let (name, value) = headers_iter.next().unwrap(); - assert_eq!(constants::QUERY_ENABLE_CROSS_PARTITION, name); - assert_eq!("True", value.as_str()); - } - - #[test] - fn from_vec_single() { - let keys = vec![PartitionKeyValue::from("tenant1")]; - let partition_key = PartitionKey::from(keys); - assert_eq!(key_to_string(partition_key), r#"["tenant1"]"#); - } - - #[test] - fn from_vec_double() { - let keys = vec![ - PartitionKeyValue::from("tenant1"), - PartitionKeyValue::from("region1"), - ]; - let partition_key = PartitionKey::from(keys); - assert_eq!(key_to_string(partition_key), r#"["tenant1","region1"]"#); - } - - #[test] - fn from_vec_triple() { - let keys = vec![ - PartitionKeyValue::from("tenant1"), - PartitionKeyValue::from("region1"), - PartitionKeyValue::from("user1"), - ]; - let partition_key = PartitionKey::from(keys); - assert_eq!( - key_to_string(partition_key), - r#"["tenant1","region1","user1"]"# - ); - } - - #[test] - fn from_vec_mixed_types() { - let keys = vec![ - PartitionKeyValue::from("tenant1"), - PartitionKeyValue::from(42i64), - PartitionKeyValue::from(123.45f64), - ]; - let partition_key = PartitionKey::from(keys); - assert_eq!(key_to_string(partition_key), r#"["tenant1",42,123.45]"#); - } - - #[test] - #[should_panic(expected = "Partition keys can have at most 3 levels, got 4")] - fn from_vec_too_many() { - let keys = vec![ - PartitionKeyValue::from("a"), - PartitionKeyValue::from("b"), - PartitionKeyValue::from("c"), - PartitionKeyValue::from("d"), - ]; - let _partition_key = PartitionKey::from(keys); - } - - #[test] - fn null_value() { - assert_eq!(key_to_string(PartitionKey::NULL), r#"[null]"#); - assert_eq!( - key_to_string((PartitionKey::NULL, PartitionKey::NULL, PartitionKey::NULL)), - r#"[null,null,null]"# - ); - } - - #[test] - pub fn non_ascii_string() { - let key = PartitionKey::from("smile 😀"); - assert_eq!(key_to_string(key), r#"["smile \ud83d\ude00"]"#); - } - - #[test] - pub fn tuple() { - assert_eq!( - key_to_string((42u8, "my_partition_key", PartitionKey::NULL)), - r#"[42,"my_partition_key",null]"# - ); - } - - #[test] - pub fn empty() { - let partition_key = PartitionKey::from(()); - assert_eq!(Vec::::new(), partition_key.0); - - let mut headers_iter = partition_key.as_headers().unwrap(); - let (name, value) = headers_iter.next().unwrap(); - assert_eq!(constants::QUERY_ENABLE_CROSS_PARTITION, name); - assert_eq!("True", value.as_str()); - } - - /// Helper to get the partition key header value (not cross-partition header). - fn key_to_pk_header(v: impl Into) -> (String, String) { - let key = v.into(); - let mut headers_iter = key.as_headers().unwrap(); - let (name, value) = headers_iter.next().unwrap(); - (name.as_str().to_string(), value.as_str().to_string()) - } - - #[test] - fn undefined_single() { - // A single UNDEFINED value should produce [{}] via the partition key header, - // where {} is the wire representation of an undefined partition key component. - let (name, value) = key_to_pk_header(PartitionKey::UNDEFINED); - assert_eq!(constants::PARTITION_KEY.as_str(), name); - assert_eq!("[{}]", value); - } - - #[test] - fn undefined_all_in_hierarchical() { - // All UNDEFINED values in a hierarchical key should produce [{},{}]. - let (name, value) = key_to_pk_header((PartitionKey::UNDEFINED, PartitionKey::UNDEFINED)); - assert_eq!(constants::PARTITION_KEY.as_str(), name); - assert_eq!("[{},{}]", value); - } - - #[test] - fn undefined_mixed_with_values() { - // UNDEFINED values should be serialized as {} in the JSON array. - assert_eq!( - key_to_string(("parent", PartitionKey::UNDEFINED)), - r#"["parent",{}]"# - ); - assert_eq!( - key_to_string((PartitionKey::UNDEFINED, "child")), - r#"[{},"child"]"# - ); - } - - #[test] - fn undefined_distinct_from_null() { - // UNDEFINED produces [{}] while NULL produces [null]. - let (undef_name, undef_value) = key_to_pk_header(PartitionKey::UNDEFINED); - let null_value = key_to_string(PartitionKey::NULL); - assert_eq!(constants::PARTITION_KEY.as_str(), undef_name); - assert_eq!("[{}]", undef_value); - assert_eq!("[null]", null_value); - } - - #[test] - fn undefined_distinct_from_empty() { - // UNDEFINED sends the partition key header with `[{}]`, while EMPTY sends the cross-partition header. - let (undef_name, undef_value) = key_to_pk_header(PartitionKey::UNDEFINED); - assert_eq!(constants::PARTITION_KEY.as_str(), undef_name); - assert_eq!("[{}]", undef_value); - - let empty = PartitionKey::EMPTY; - let mut headers_iter = empty.as_headers().unwrap(); - let (empty_name, empty_value) = headers_iter.next().unwrap(); - assert_eq!(constants::QUERY_ENABLE_CROSS_PARTITION, empty_name); - assert_eq!("True", empty_value.as_str()); - } - - #[test] - fn undefined_in_vec() { - let keys = vec![PartitionKeyValue::from("tenant1"), PartitionKey::UNDEFINED]; - let partition_key = PartitionKey::from(keys); - assert_eq!(key_to_string(partition_key), r#"["tenant1",{}]"#); - } -} diff --git a/sdk/cosmos/azure_data_cosmos/src/query/mod.rs b/sdk/cosmos/azure_data_cosmos/src/query.rs similarity index 82% rename from sdk/cosmos/azure_data_cosmos/src/query/mod.rs rename to sdk/cosmos/azure_data_cosmos/src/query.rs index 44befb366fa..4aebb661955 100644 --- a/sdk/cosmos/azure_data_cosmos/src/query/mod.rs +++ b/sdk/cosmos/azure_data_cosmos/src/query.rs @@ -3,11 +3,50 @@ //! Models and components used to represents and execute queries. +use azure_data_cosmos_driver::models::{FeedRange, OperationTarget, PartitionKey}; use serde::Serialize; -pub(crate) mod executor; +/// Represents the scope of a query, which determines which partitions it targets. +/// +/// The Cosmos DB backend can only execute queries against a single physical partition at a time, +/// so it is important to choose the appropriate scope for your query to ensure it is executed efficiently. +/// Queries that cross physical partition boundaries require the client to fan out the query to +/// multiple partitions and aggregate the results, which can be expensive and slow for large datasets. +#[derive(Clone)] +pub enum QueryScope { + Partition(PartitionKey), + FeedRange(FeedRange), +} + +impl QueryScope { + /// Returns a [`QueryScope`] that represents the given partition key, which is used for targeting a specific partition in the container. + pub fn partition(pk: impl Into) -> Self { + Self::Partition(pk.into()) + } + + /// Returns a [`QueryScope`] that represents the given feed range, which can be used for partition-specific or cross-partition queries depending on the feed range provided. + /// + /// WARNING: Using a feed range that covers multiple partitions may result in a full scan of those partitions, which can be expensive and slow for large datasets. Use with caution. + pub fn feed_range(fr: FeedRange) -> Self { + Self::FeedRange(fr) + } -pub use executor::QueryExecutor; + /// Returns a [`QueryScope`] that represents the full container, which is used for cross-partition queries. + /// + /// WARNING: Using this query scope may result in a full scan of the container, which can be expensive and slow for large datasets. Use with caution. + pub fn full_container() -> Self { + Self::FeedRange(FeedRange::full()) + } +} + +impl From for OperationTarget { + fn from(value: QueryScope) -> Self { + match value { + QueryScope::Partition(pk) => Self::PartitionKey(pk), + QueryScope::FeedRange(fr) => Self::FeedRange(fr), + } + } +} /// Represents a Cosmos DB Query, with optional parameters. /// diff --git a/sdk/cosmos/azure_data_cosmos/src/query/executor.rs b/sdk/cosmos/azure_data_cosmos/src/query/executor.rs deleted file mode 100644 index bf3d8b0fac7..00000000000 --- a/sdk/cosmos/azure_data_cosmos/src/query/executor.rs +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -//! Query execution implementation. - -use std::{collections::HashMap, sync::Arc}; - -use azure_core::http::headers::{HeaderName, HeaderValue}; -use azure_data_cosmos_driver::{ - models::{CosmosOperation, SessionToken}, - options::OperationOptions as DriverOperationOptions, - CosmosDriver, -}; -use serde::de::DeserializeOwned; - -use crate::{constants, driver_bridge, feed::FeedBody, Query, QueryFeedPage}; - -/// A query executor that sends queries through the Cosmos driver. -/// -/// This executor handles pagination via continuation tokens and works for -/// item queries (with partition key), database queries, and container queries. -/// The `operation_factory` closure produces the appropriate `CosmosOperation` -/// for each page request. -pub struct QueryExecutor { - driver: Arc, - operation_factory: Box CosmosOperation + Send>, - query: Query, - query_body: Option>, - base_options: DriverOperationOptions, - base_headers: HashMap, - session_token: Option, - continuation: Option, - complete: bool, - // Why is our phantom type a function? Because that represents how we _use_ the type T. - // Normally, PhantomData is only Send/Sync if T is, because PhantomData is indicating that while we don't _name_ T in a field, we should act as though we have a field of type T. - // However, we don't store any T values in this, we only RETURN them. - // That means we use a function pointer to indicate that we don't actually operate on T directly, we just return it. - // Because of this, PhantomData T> is Send/Sync even if T isn't (see https://doc.rust-lang.org/stable/nomicon/phantom-data.html#table-of-phantomdata-patterns) - phantom: std::marker::PhantomData T>, -} - -impl QueryExecutor { - pub(crate) fn new( - driver: Arc, - operation_factory: impl Fn() -> CosmosOperation + Send + 'static, - query: Query, - base_options: DriverOperationOptions, - session_token: Option, - ) -> Self { - // Pre-build the static headers that are the same for every page: - // user-provided custom headers + query-specific constants. - let mut base_headers = base_options.custom_headers().cloned().unwrap_or_default(); - base_headers.insert(constants::QUERY.clone(), HeaderValue::from_static("True")); - base_headers.insert( - azure_core::http::headers::CONTENT_TYPE, - HeaderValue::from_static("application/query+json"), - ); - - Self { - driver, - operation_factory: Box::new(operation_factory), - query, - query_body: None, - base_options, - base_headers, - session_token, - continuation: None, - complete: false, - phantom: std::marker::PhantomData, - } - } - - /// Consumes the executor and converts it into a stream of pages. - pub fn into_stream(self) -> azure_core::Result> { - Ok(crate::FeedItemIterator::new(futures::stream::try_unfold( - self, - |mut state| async move { - let val = state.next_page().await?; - Ok(val.map(|item| (item, state))) - }, - ))) - } - - /// Fetches the next page of query results. - /// - /// Returns `None` if there are no more pages to fetch. - pub async fn next_page(&mut self) -> azure_core::Result>> { - if self.complete { - return Ok(None); - } - - // Build a fresh operation for this page - let mut operation = (self.operation_factory)(); - - // Serialize the query body on the first page and cache it for subsequent pages. - if self.query_body.is_none() { - self.query_body = Some(serde_json::to_vec(&self.query)?); - } - operation = operation.with_body(self.query_body.clone().unwrap()); - - // The explicit session token serves as an initial hint; the driver's - // internal session manager captures response tokens and applies them - // to subsequent requests automatically. - if let Some(session_token) = &self.session_token { - operation = operation.with_session_token(session_token.clone()); - } - - // Clone the pre-built static headers and add the continuation token - // (the only header that changes between pages). - let mut headers = self.base_headers.clone(); - if let Some(continuation) = &self.continuation { - headers.insert( - constants::CONTINUATION.clone(), - HeaderValue::from(continuation.clone()), - ); - } - - let op_options = self.base_options.clone().with_custom_headers(headers); - - // Execute through the driver - let driver_response = self.driver.execute_operation(operation, op_options).await?; - - // Bridge driver response to SDK types - let cosmos_response = - driver_bridge::driver_response_to_cosmos_response::>(driver_response); - - let page = QueryFeedPage::::from_response(cosmos_response).await?; - - match page.continuation() { - Some(token) => self.continuation = Some(token.to_string()), - None => self.complete = true, - } - - Ok(Some(page)) - } -} diff --git a/sdk/cosmos/azure_data_cosmos/src/retry_policies/client_retry_policy.rs b/sdk/cosmos/azure_data_cosmos/src/retry_policies/client_retry_policy.rs new file mode 100644 index 00000000000..9d77656b7ea --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos/src/retry_policies/client_retry_policy.rs @@ -0,0 +1,1798 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +use super::{ + get_substatus_code_from_error, get_substatus_code_from_response, is_non_retryable_status_code, + resource_throttle_retry_policy::ResourceThrottleRetryPolicy, RequestSentExt, RequestSentStatus, + RetryResult, +}; +use crate::constants::{self, SubStatusCode}; +use crate::cosmos_request::CosmosRequest; +use crate::operation_context::OperationType; +use crate::regions::Region; +use crate::routing::global_endpoint_manager::GlobalEndpointManager; +use crate::routing::global_partition_endpoint_manager::GlobalPartitionEndpointManager; +use azure_core::error::ErrorKind; +use azure_core::http::{RawResponse, StatusCode}; +use azure_core::time::Duration; +use std::sync::Arc; +use tracing::error; +use url::Url; + +/// An integer indicating the default retry intervals between two retry attempts. +const RETRY_INTERVAL_MS: i64 = 1000; + +/// An integer indicating the maximum retry count on endpoint failures. +const MAX_RETRY_COUNT_ON_ENDPOINT_FAILURE: usize = 120; + +/// An integer indicating the maximum retry count on connection failures before marking +/// the endpoint unavailable. +const MAX_RETRY_COUNT_ON_CONNECTION_FAILURE: usize = 3; + +/// Context information for routing retry attempts to specific endpoints. +#[derive(Clone, Debug)] +struct RetryContext { + /// Index of the location to route the retry request to + retry_location_index: usize, + + /// Whether to retry on preferred locations only (true) or all available locations (false) + retry_request_on_preferred_locations: bool, + + /// Whether to route directly to the hub endpoint instead of using location-based routing + route_to_hub: bool, +} + +/// Retry policy for handling data plane request failures. +#[derive(Debug)] +pub(crate) struct ClientRetryPolicy { + /// Manages multi-region endpoint routing and failover logic + global_endpoint_manager: Arc, + + /// An instance of GlobalPartitionEndpointManager that manages partition key range to endpoint mapping + partition_key_range_location_cache: Arc, + + /// Whether automatic endpoint discovery is enabled for failover scenarios + enable_endpoint_discovery: bool, + + /// Counter tracking the number of endpoint failover retry attempts + failover_retry_count: usize, + + /// Counter tracking the number of session token unavailability retry attempts + session_token_retry_count: usize, + + /// Counter tracking the number of service unavailable (503) retry attempts + service_unavailable_retry_count: usize, + + /// Counter tracking the number of consecutive connection failure retry attempts + /// on the current endpoint before marking it unavailable. + connection_retry_count: usize, + + /// Whether the current request is a read operation (true) or write operation (false) + operation_type: Option, + + /// The Cosmos request being processed by the retry policy. + request: Option, + + /// Whether the account supports writing to multiple locations simultaneously + can_use_multiple_write_locations: bool, + + /// The resolved endpoint URL for the current or next request attempt + location_endpoint: Option, + + /// Context information for routing the next retry attempt to a specific location + retry_context: Option, + + /// Regions excluded from routing for the current request + excluded_regions: Option>, + + /// Underlying policy for handling resource throttling (429) with exponential backoff + throttling_retry: ResourceThrottleRetryPolicy, +} + +impl ClientRetryPolicy { + /// Creates a new ClientRetryPolicy instance. + /// + /// # Summary + /// Initializes a retry policy that handles various failure scenarios including session token + /// mismatches, endpoint failures, service unavailability, and resource throttling. The policy + /// manages automatic endpoint discovery, multi-region failover, and coordinates with the + /// GlobalEndpointManager for routing decisions. It wraps a ResourceThrottleRetryPolicy for + /// handling 429 (TooManyRequests) responses with exponential backoff. + /// + /// # Arguments + /// * `global_endpoint_manager` - The endpoint manager for handling multi-region routing and failover + /// + /// # Returns + /// A new `ClientRetryPolicy` instance configured with default retry limits and throttling behavior + pub fn new( + global_endpoint_manager: Arc, + partition_key_range_location_cache: Arc, + excluded_regions: Option>, + ) -> Self { + Self { + global_endpoint_manager, + partition_key_range_location_cache, + enable_endpoint_discovery: true, + failover_retry_count: 0, + session_token_retry_count: 0, + service_unavailable_retry_count: 0, + connection_retry_count: 0, + operation_type: None, + request: None, + can_use_multiple_write_locations: false, + location_endpoint: None, + retry_context: None, + excluded_regions, + throttling_retry: ResourceThrottleRetryPolicy::new(5, 200, 10), + } + } + + /// Returns whether the current operation is read-only. + /// + /// Defaults to `true` if the operation type has not been set, which is the + /// conservative choice: reads are always safe to retry. + fn is_read_only(&self) -> bool { + self.operation_type.is_none_or(|op| op.is_read_only()) + } + + /// Prepares a request before it is sent, configuring routing and endpoint selection. + /// + /// # Summary + /// Performs pre-flight setup for each request attempt by refreshing location cache, + /// determining request characteristics (read vs write, multi-master support), and + /// resolving the target endpoint based on retry context. Handles location-based routing + /// directives, including retry attempts that target specific location indices or the hub + /// endpoint. Clears previous routing context and configures the request with the + /// appropriate endpoint URL for the current attempt. + /// + /// # Arguments + /// * `request` - The mutable request to configure before sending + pub(crate) async fn before_send_request(&mut self, request: &mut CosmosRequest) { + // Ideally, any request flow should not be blocked by the outcome of refresh_location. + // There can be three possible cases: + // a) The refresh_location succeeds when TTL expires. + // b) The refresh_location is bypassed when TTL hasn't expired. + // c) The refresh_location operation has failed. In the event of a failure, + // the error is logged and the request should not be blocked. + // Hence, the outcome of the operation is ignored here. + _ = self.global_endpoint_manager.refresh_location(false).await; + self.operation_type = Some(request.operation_type); + self.excluded_regions = request.excluded_regions.clone().map(|e| e.0); + self.can_use_multiple_write_locations = self + .global_endpoint_manager + .can_use_multiple_write_locations(request); + + if self.can_use_multiple_write_locations { + request + .headers + .insert(constants::ALLOW_TENTATIVE_WRITES, "true"); + } else { + request.headers.remove(constants::ALLOW_TENTATIVE_WRITES); + } + + // Clear previous location-based routing directive + request.request_context.clear_route_to_location(); + + if let Some(ref ctx) = self.retry_context { + let mut req_ctx = request.request_context.clone(); + if ctx.route_to_hub { + req_ctx.route_to_location_endpoint( + request + .resource_link + .url(self.global_endpoint_manager.hub_uri()), + ); + } else { + req_ctx.route_to_location_index( + ctx.retry_location_index, + ctx.retry_request_on_preferred_locations, + ); + } + request.request_context = req_ctx; + } + + // Resolve the endpoint for the request + self.location_endpoint = Some( + self.global_endpoint_manager + .resolve_service_endpoint(request), + ); + + tracing::trace!( + ?self.location_endpoint, + "routed request to endpoint" + ); + + if let Some(ref endpoint) = self.location_endpoint { + request + .request_context + .route_to_location_endpoint(endpoint.clone()); + } + + if self + .partition_key_range_location_cache + .partition_level_failover_enabled() + && request.resource_type.is_partitioned() + { + self.partition_key_range_location_cache + .try_add_partition_level_location_override(request); + } + + self.request = Some(request.clone()); + } + + /// Determines whether a Data Plane request should be retried based on the response or error + /// + /// # Summary + /// Evaluates the result of a request attempt to determine if it should be retried. + /// Distinguishes between successful responses (2xx), client/server error responses + /// (4xx/5xx), and transport/network errors. Delegates error responses to + /// `should_retry_response` and exceptions to `should_retry_error` for detailed + /// evaluation. Non-error responses (2xx, 3xx) are not retried. This method is + /// called by the retry framework after each request attempt. + /// + /// # Arguments + /// * `response` - The result of the request attempt (Ok with response or Err with error) + /// + /// # Returns + /// A `RetryResult`: + /// - `Retry { after: Duration }` if the request should be retried with specified delay + /// - `DoNotRetry` for successful responses or non-retryable failures + pub(crate) async fn should_retry( + &mut self, + response: &azure_core::Result, + ) -> RetryResult { + match response { + Ok(resp) if resp.status().is_server_error() || resp.status().is_client_error() => { + self.should_retry_response(resp).await + } + Ok(_) => RetryResult::DoNotRetry, + Err(err) => self.should_retry_error(err).await, + } + } + + /// Determines if a request should be retried when session token is unavailable. + /// + /// # Summary + /// Handles 404.1022 (READ_SESSION_NOT_AVAILABLE) errors by attempting to retry on different + /// endpoints. For multi-write scenarios, tries all available endpoints before giving up. + /// For single-write scenarios, retries once on the primary write location. Increments the + /// session token retry counter and configures retry context for endpoint routing. + /// + /// # Arguments + /// * `cosmos_request` - The original request that failed with session token unavailable + /// + /// # Returns + /// A `RetryResult`: + /// - `Retry { after: Duration::ZERO }` if retry is allowed on a different endpoint + /// - `DoNotRetry` if endpoint discovery is disabled or all endpoints have been tried + fn should_retry_on_session_not_available(&mut self) -> RetryResult { + self.session_token_retry_count += 1; + + // If endpoint discovery is disabled, the request cannot be retried anywhere else + if !self.enable_endpoint_discovery { + return RetryResult::DoNotRetry; + } + + if self.can_use_multiple_write_locations { + let endpoints = self.global_endpoint_manager.applicable_endpoints( + self.operation_type.unwrap_or(OperationType::Read), + self.excluded_regions.as_ref(), + ); + if self.session_token_retry_count > endpoints.len() { + // When use multiple write locations is true and the request has been tried on all locations, then don't retry the request. + RetryResult::DoNotRetry + } else { + self.retry_context = Some(RetryContext { + retry_location_index: self.session_token_retry_count, + retry_request_on_preferred_locations: true, + route_to_hub: false, + }); + + RetryResult::Retry { + after: Duration::ZERO, + } + } + } else if self.session_token_retry_count > 1 { + // When cannot use multiple write locations, then don't retry the request if + // we have already tried this request on the write location + RetryResult::DoNotRetry + } else { + self.retry_context = Some(RetryContext { + retry_location_index: 0, + retry_request_on_preferred_locations: false, + route_to_hub: false, + }); + + RetryResult::Retry { + after: Duration::ZERO, + } + } + } + + /// Determines if a request should be retried after a connection failure. + /// + /// Connection failures mean the request was never sent to the server, so both + /// reads and writes are safe to retry. The strategy is: + /// + /// 1. Retry up to [`MAX_RETRY_COUNT_ON_CONNECTION_FAILURE`] times on the same + /// endpoint with a delay — the failure may be transient. + /// 2. After exhausting local retries, mark the endpoint unavailable for both + /// reads and writes, refresh the location cache, and fail over to the next + /// available endpoint. + async fn should_retry_on_connection_failure(&mut self) -> RetryResult { + self.connection_retry_count += 1; + + if self.connection_retry_count <= MAX_RETRY_COUNT_ON_CONNECTION_FAILURE { + // Retry on the same endpoint — the connection failure may be transient. + return RetryResult::Retry { + after: Duration::milliseconds(RETRY_INTERVAL_MS), + }; + } + + // Exhausted local retries — mark endpoint unavailable and fail over. + if let Some(ref endpoint) = self.location_endpoint { + self.global_endpoint_manager + .mark_endpoint_unavailable_for_read(endpoint); + self.global_endpoint_manager + .mark_endpoint_unavailable_for_write(endpoint); + } + + self.failover_retry_count += 1; + if self.failover_retry_count > MAX_RETRY_COUNT_ON_ENDPOINT_FAILURE + || !self.enable_endpoint_discovery + { + return RetryResult::DoNotRetry; + } + + _ = self.global_endpoint_manager.refresh_location(true).await; + + // Reset connection retry counter for the new endpoint. + self.connection_retry_count = 0; + + self.retry_context = Some(RetryContext { + retry_location_index: 0, + retry_request_on_preferred_locations: true, + route_to_hub: false, + }); + + RetryResult::Retry { + after: Duration::ZERO, + } + } + + /// Determines if a request should be retried when an endpoint fails. + /// + /// # Summary + /// Handles endpoint failures by marking failed endpoints as unavailable and attempting retry + /// on alternative endpoints. Refreshes the location cache to get updated endpoint information + /// and configures retry delays based on request type (write requests get longer delays). + /// Respects maximum retry limits and endpoint discovery settings. Can mark endpoints as + /// unavailable for reads, writes, or both depending on the failure scenario. + /// + /// # Arguments + /// * `is_read_request` - Whether this is a read operation + /// * `mark_both_read_and_write_as_unavailable` - Whether to mark the endpoint unavailable for both operations + /// * `force_refresh` - Whether to force refresh of the location cache + /// * `retry_on_preferred_locations` - Whether to retry on preferred locations first + /// * `overwrite_endpoint_discovery` - Whether to bypass endpoint discovery checks + /// + /// # Returns + /// A `RetryResult`: + /// - `Retry { after: Duration }` with appropriate delay if retry is allowed + /// - `DoNotRetry` if max retry count exceeded or endpoint discovery disabled + async fn should_retry_on_endpoint_failure( + &mut self, + is_read_request: bool, + mark_both_read_and_write_as_unavailable: bool, + force_refresh: bool, + retry_on_preferred_locations: bool, + overwrite_endpoint_discovery: bool, + ) -> RetryResult { + if self.failover_retry_count > MAX_RETRY_COUNT_ON_ENDPOINT_FAILURE + || (!self.enable_endpoint_discovery && !overwrite_endpoint_discovery) + { + return RetryResult::DoNotRetry; + } + + self.failover_retry_count += 1; + + if let Some(ref endpoint) = self.location_endpoint { + if !overwrite_endpoint_discovery { + if is_read_request || mark_both_read_and_write_as_unavailable { + self.global_endpoint_manager + .mark_endpoint_unavailable_for_read(endpoint); + } + if !is_read_request || mark_both_read_and_write_as_unavailable { + self.global_endpoint_manager + .mark_endpoint_unavailable_for_write(endpoint); + } + } + } + + let retry_delay = if !is_read_request { + if self.failover_retry_count > 1 { + Duration::milliseconds(RETRY_INTERVAL_MS) + } else { + Duration::ZERO + } + } else { + Duration::milliseconds(RETRY_INTERVAL_MS) + }; + + // Ideally, any request flow should not be blocked by the outcome of refresh_location. + // There can be three possible cases: + // a) The refresh_location succeeds when TTL expires. + // b) The refresh_location is bypassed when TTL hasn't expired. + // c) The refresh_location operation has failed. In the event of a failure, + // the error is logged and the request should not be blocked. + // Hence, the outcome of the operation is ignored here. + _ = self + .global_endpoint_manager + .refresh_location(force_refresh) + .await; + let retry_location_index = if retry_on_preferred_locations { + 0 + } else { + self.failover_retry_count + }; + + self.retry_context = Some(RetryContext { + retry_location_index, + retry_request_on_preferred_locations: retry_on_preferred_locations, + route_to_hub: false, + }); + + RetryResult::Retry { after: retry_delay } + } + + /// Determines if a request should be retried for service unavailable status codes. + /// + /// # Summary + /// Handles 503 (ServiceUnavailable), 500 (InternalServerError for reads), and 410 with + /// LeaseNotFound errors by attempting retry on all applicable endpoints (all regions minus + /// excluded regions, in preference of preferred regions). Requires multi-write support for + /// write operations. Configures retry context to route to the next preferred location. + /// + /// # Returns + /// A `RetryResult`: + /// - `Retry { after: Duration::ZERO }` if retry conditions are met + /// - `DoNotRetry` if all endpoints tried or write without multi-write support + fn should_retry_on_unavailable_endpoint_status_codes(&mut self) -> RetryResult { + self.service_unavailable_retry_count += 1; + + if !self.can_use_multiple_write_locations + && !self + .operation_type + .as_ref() + .is_some_and(|op| op.is_read_only()) + { + return RetryResult::DoNotRetry; + } + + // automatic failover support needed to be plugged in. + if !self.can_use_multiple_write_locations + && !self.is_read_only() + && !self + .partition_key_range_location_cache + .partition_level_automatic_failover_enabled() + { + return RetryResult::DoNotRetry; + } + + let endpoints = self + .global_endpoint_manager + .applicable_endpoints(self.operation_type.unwrap(), self.excluded_regions.as_ref()); + + if self.service_unavailable_retry_count > endpoints.len() { + return RetryResult::DoNotRetry; + } + + self.retry_context = Some(RetryContext { + retry_location_index: self.service_unavailable_retry_count, + retry_request_on_preferred_locations: true, + route_to_hub: false, + }); + + RetryResult::Retry { + after: Duration::ZERO, + } + } + + /// Routes HTTP status codes to appropriate retry handling logic. + /// + /// # Summary + /// Evaluates HTTP status code and Cosmos DB sub-status code combinations to determine + /// the appropriate retry strategy. Handles specific scenarios: 403.3 (WriteForbidden) + /// triggers endpoint failover with cache refresh, 404.1022 (READ_SESSION_NOT_AVAILABLE) + /// retries on different endpoints, 503 (ServiceUnavailable) attempts preferred location + /// failover, and 500/408/410 with LeaseNotFound retry on alternative endpoints for reads. + /// + /// For read operations, any status code that is not considered non-retryable by + /// [`is_non_retryable_status_code`] is retried on an alternative endpoint. For write + /// operations, unhandled status codes are delegated to the throttling policy. + /// + /// # Arguments + /// * `status_code` - The HTTP status code from the response + /// * `sub_status_code` - The Cosmos DB specific sub-status code + /// + /// # Returns + /// An `Option`: + /// - `Some(RetryResult)` if the status code requires special retry handling + /// - `None` if the status code should be delegated to the throttling policy + async fn should_retry_on_http_status( + &mut self, + status_code: StatusCode, + sub_status_code: Option, + ) -> Option { + // Forbidden - Write forbidden (403.3) + if status_code == StatusCode::Forbidden + && sub_status_code == Some(SubStatusCode::WRITE_FORBIDDEN) + { + if self.request.is_some() + && (self.is_request_eligible_for_per_partition_automatic_failover() + || self.increment_failure_counter_and_check_circuit_breaker_eligibility()) + && self + .partition_key_range_location_cache + .try_mark_endpoint_unavailable_for_partition_key_range( + self.request.as_ref().unwrap(), + ) + { + return Some(RetryResult::Retry { + after: Duration::ZERO, + }); + } + + return Some( + self.should_retry_on_endpoint_failure(false, false, true, false, false) + .await, + ); + } + + // Read Session Not Available (404.1022) + if status_code == StatusCode::NotFound + && sub_status_code == Some(SubStatusCode::READ_SESSION_NOT_AVAILABLE) + { + return Some(self.should_retry_on_session_not_available()); + } + + if self.should_mark_endpoint_unavailable_on_system_resource_unavailable_for_write( + Some(status_code), + sub_status_code, + ) { + error!( + "Operation will NOT be retried on local region. \ + Treating SystemResourceUnavailable (429/3092) as ServiceUnavailable (503). \ + Status code: 429, sub status code: 3092" + ); + + return Some( + self.try_mark_endpoint_unavailable_for_pk_range_and_retry_on_service_unavailable( + true, + ), + ); + } + + // Service unavailable (503) + if status_code == StatusCode::ServiceUnavailable { + return Some( + self.try_mark_endpoint_unavailable_for_pk_range_and_retry_on_service_unavailable( + false, + ), + ); + } + + // Gone - Lease not found (410.1022) applies to both reads and writes + if status_code == StatusCode::Gone + && sub_status_code == Some(SubStatusCode::LEASE_NOT_FOUND) + { + return Some(self.should_retry_on_unavailable_endpoint_status_codes()); + } + + // For read operations, retry on any status code that is not explicitly non-retryable. + // This ensures transient server errors are retried on alternative endpoints. + if self.is_read_only() && !is_non_retryable_status_code(status_code, sub_status_code) { + return Some(self.should_retry_on_unavailable_endpoint_status_codes()); + } + + None + } + + /// Marks endpoint unavailable for partition key range and retries on service unavailable. + fn try_mark_endpoint_unavailable_for_pk_range_and_retry_on_service_unavailable( + &mut self, + is_system_resource_unavailable_for_write: bool, + ) -> RetryResult { + self.try_mark_endpoint_unavailable_for_pk_range(is_system_resource_unavailable_for_write); + self.should_retry_on_unavailable_endpoint_status_codes() + } + + /// Attempts to mark the endpoint unavailable for the partition key range. + fn try_mark_endpoint_unavailable_for_pk_range( + &self, + is_system_resource_unavailable_for_write: bool, + ) -> bool { + if let Some(request) = self.request.as_ref() { + if is_system_resource_unavailable_for_write + || self.is_request_eligible_for_per_partition_automatic_failover() + || self.increment_failure_counter_and_check_circuit_breaker_eligibility() + { + return self + .partition_key_range_location_cache + .try_mark_endpoint_unavailable_for_partition_key_range(request); + } + } + false + } + + /// Checks if endpoint should be marked unavailable on system resource unavailable for write. + fn should_mark_endpoint_unavailable_on_system_resource_unavailable_for_write( + &self, + status_code: Option, + sub_status_code: Option, + ) -> bool { + self.can_use_multiple_write_locations + && status_code == Some(StatusCode::TooManyRequests) + && sub_status_code == Some(SubStatusCode::SYSTEM_RESOURCE_NOT_AVAILABLE) + } + + /// Checks if request is eligible for per-partition automatic failover. + fn is_request_eligible_for_per_partition_automatic_failover(&self) -> bool { + if let Some(request) = self.request.as_ref() { + return self + .partition_key_range_location_cache + .is_request_eligible_for_per_partition_automatic_failover(request); + } + false + } + + /// Increments failure counter and checks if request is eligible for partition-level circuit breaker. + fn increment_failure_counter_and_check_circuit_breaker_eligibility(&self) -> bool { + if let Some(request) = self.request.as_ref() { + return self + .partition_key_range_location_cache + .is_request_eligible_for_partition_level_circuit_breaker(request) + && self + .partition_key_range_location_cache + .increment_request_failure_counter_and_check_if_partition_can_failover( + request, + ); + } + false + } + + /// Evaluates an error to determine if the request should be retried. + /// + /// # Summary + /// First checks the [`RequestSentStatus`] to handle transport-level errors: + /// - `NotSent`: retries reads and writes (request never reached server). + /// - `Sent`/`Unknown` with transport errors (`Timeout`, `Io`): retries reads only. + /// + /// For HTTP-level errors, delegates to `should_retry_on_http_status` for + /// scenario-specific retry logic (403.3, 404.1022, 503, 500, 410), then falls + /// back to the throttling retry policy for 429 (TooManyRequests). + /// + /// # Arguments + /// * `err` - The error that occurred during the request + /// + /// # Returns + /// A `RetryResult` indicating whether to retry and with what delay + async fn should_retry_error(&mut self, err: &azure_core::Error) -> RetryResult { + // Determine whether the request was actually sent to the server. + // This drives the retry decision for transport-level errors: + // - NotSent: safe to retry reads and writes (request never reached server) + // - Sent/Unknown: only retry reads (write may have been applied) + match err.request_sent_status() { + RequestSentStatus::NotSent => { + return self.should_retry_on_connection_failure().await; + } + RequestSentStatus::Sent | RequestSentStatus::Unknown => { + if matches!(err.kind(), ErrorKind::Io) { + if self.is_read_only() { + return self.should_retry_on_unavailable_endpoint_status_codes(); + } + return RetryResult::DoNotRetry; + } + } + } + + let status_code = err.http_status().unwrap_or(StatusCode::UnknownValue(0)); + let sub_status_code = get_substatus_code_from_error(err); + + if let Some(result) = self + .should_retry_on_http_status(status_code, sub_status_code) + .await + { + return result; + } + + self.throttling_retry.should_retry_error(err) + } + + /// Evaluates an HTTP response to determine if the request should be retried. + /// + /// # Summary + /// Extracts HTTP status code and sub-status code from the response and delegates to + /// `should_retry_on_http_status` for scenario-specific retry logic. If the response + /// doesn't match any special retry cases (403.3, 404.1022, 503, 500, 410), falls + /// back to the throttling retry policy which handles 429 (TooManyRequests) responses + /// with exponential backoff. + /// + /// # Arguments + /// * `response` - The HTTP response received from the service + /// + /// # Returns + /// A `RetryResult` indicating whether to retry and with what delay + async fn should_retry_response(&mut self, response: &RawResponse) -> RetryResult { + let status_code = response.status(); + let sub_status_code = get_substatus_code_from_response(response); + + if let Some(result) = self + .should_retry_on_http_status(status_code, sub_status_code) + .await + { + return result; + } + + self.throttling_retry.should_retry_response(response) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::models::AccountRegion; + use crate::operation_context::OperationType; + use crate::partition_key::PartitionKey; + use crate::regions::Region; + use crate::resource_context::{ResourceLink, ResourceType}; + use crate::routing::global_endpoint_manager::GlobalEndpointManager; + use crate::routing::partition_key_range::PartitionKeyRange; + use azure_core::http::headers::Headers; + use azure_core::http::ClientOptions; + use azure_core::Bytes; + use std::sync::Arc; + + fn create_test_endpoint_manager() -> Arc { + let pipeline = azure_core::http::Pipeline::new( + option_env!("CARGO_PKG_NAME"), + option_env!("CARGO_PKG_VERSION"), + ClientOptions::default(), + Vec::new(), + Vec::new(), + None, + ); + + GlobalEndpointManager::new( + "https://test.documents.azure.com".parse().unwrap(), + vec![Region::from("West US"), Region::from("East US")], + vec![], + pipeline, + ) + } + + fn create_test_endpoint_manager_no_locations() -> Arc { + let pipeline = azure_core::http::Pipeline::new( + option_env!("CARGO_PKG_NAME"), + option_env!("CARGO_PKG_VERSION"), + ClientOptions::default(), + Vec::new(), + Vec::new(), + None, + ); + + GlobalEndpointManager::new( + "https://test.documents.azure.com".parse().unwrap(), + vec![], + vec![], + pipeline, + ) + } + + fn create_test_endpoint_manager_with_preferred_locations() -> Arc { + let pipeline = azure_core::http::Pipeline::new( + option_env!("CARGO_PKG_NAME"), + option_env!("CARGO_PKG_VERSION"), + ClientOptions::default(), + Vec::new(), + Vec::new(), + None, + ); + + GlobalEndpointManager::new( + "https://test.documents.azure.com".parse().unwrap(), + vec![Region::EAST_ASIA, Region::WEST_US, Region::NORTH_CENTRAL_US], + vec![], + pipeline, + ) + } + + fn create_test_policy() -> ClientRetryPolicy { + let manager = create_test_endpoint_manager(); + let partition_manager = GlobalPartitionEndpointManager::new(manager.clone(), false, false); + ClientRetryPolicy::new(manager, partition_manager, None) + } + + fn create_test_policy_no_locations() -> ClientRetryPolicy { + let manager = create_test_endpoint_manager_no_locations(); + let partition_manager = GlobalPartitionEndpointManager::new(manager.clone(), false, false); + ClientRetryPolicy::new(manager, partition_manager, None) + } + + fn create_test_policy_with_preferred_locations() -> ClientRetryPolicy { + let manager = create_test_endpoint_manager_with_preferred_locations(); + let partition_manager = GlobalPartitionEndpointManager::new(manager.clone(), false, false); + ClientRetryPolicy::new(manager, partition_manager, None) + } + + fn create_test_request() -> CosmosRequest { + let resource_link = ResourceLink::root(ResourceType::Documents); + CosmosRequest::builder(OperationType::Read, resource_link.clone()) + .partition_key(PartitionKey::from("test")) + .build() + .unwrap() + } + + fn create_write_request() -> CosmosRequest { + let resource_link = ResourceLink::root(ResourceType::Documents); + CosmosRequest::builder(OperationType::Create, resource_link.clone()) + .partition_key(PartitionKey::from("test")) + .build() + .unwrap() + } + + fn create_raw_response(status_code: StatusCode) -> RawResponse { + let headers = Headers::new(); + RawResponse::from_bytes(status_code, headers, Bytes::new()) + } + + fn create_raw_response_with_substatus(status_code: StatusCode, substatus: u32) -> RawResponse { + let mut headers = Headers::new(); + headers.insert("x-ms-substatus", substatus.to_string()); + RawResponse::from_bytes(status_code, headers, Bytes::new()) + } + + fn create_error_with_status(status: StatusCode) -> azure_core::Error { + let response = create_raw_response(status); + azure_core::Error::new( + azure_core::error::ErrorKind::HttpResponse { + status: response.status(), + error_code: None, + raw_response: Some(Box::new(response)), + }, + "Test error", + ) + } + + fn create_error_with_substatus(status: StatusCode, substatus: u32) -> azure_core::Error { + let response = create_raw_response_with_substatus(status, substatus); + azure_core::Error::new( + azure_core::error::ErrorKind::HttpResponse { + status: response.status(), + error_code: None, + raw_response: Some(Box::new(response)), + }, + "Test error with substatus", + ) + } + + /// Creates a multi-master endpoint manager with two regions (West US + East US), + /// both configured as write AND read endpoints. + fn create_multi_master_endpoint_manager() -> Arc { + let pipeline = azure_core::http::Pipeline::new( + option_env!("CARGO_PKG_NAME"), + option_env!("CARGO_PKG_VERSION"), + ClientOptions::default(), + Vec::new(), + Vec::new(), + None, + ); + + let manager = GlobalEndpointManager::new( + "https://test.documents.azure.com".parse().unwrap(), + vec![Region::from("West US"), Region::from("East US")], + vec![], + pipeline, + ); + + let west = AccountRegion { + name: Region::from("West US"), + database_account_endpoint: "https://test-westus.documents.azure.com".parse().unwrap(), + }; + let east = AccountRegion { + name: Region::from("East US"), + database_account_endpoint: "https://test-eastus.documents.azure.com".parse().unwrap(), + }; + + // Both regions as write + read → multi-master account + manager.update_location_cache(vec![west.clone(), east.clone()], vec![west, east]); + manager + } + + /// Creates a policy configured for multi-master with PPCB enabled. + fn create_multi_master_ppcb_policy() -> ClientRetryPolicy { + let manager = create_multi_master_endpoint_manager(); + // PPAF = false, PPCB = true + let partition_manager = GlobalPartitionEndpointManager::new(manager.clone(), false, true); + ClientRetryPolicy::new(manager, partition_manager, None) + } + + /// Creates a write request with a PartitionKeyRange and routed endpoint, + /// as the PPCB code path requires both to be present. + fn create_ppcb_write_request() -> CosmosRequest { + let resource_link = ResourceLink::root(ResourceType::Documents); + let mut request = CosmosRequest::builder(OperationType::Create, resource_link) + .partition_key(PartitionKey::from("test")) + .build() + .unwrap(); + request.request_context.resolved_partition_key_range = + Some(PartitionKeyRange::new("0".into(), "".into(), "FF".into())); + request.request_context.resolved_collection_rid = Some("dbs/db1/colls/coll1".into()); + request.request_context.location_endpoint_to_route = + Some("https://test-westus.documents.azure.com/".parse().unwrap()); + request + } + + #[tokio::test] + async fn test_new_policy_initialization() { + let policy = create_test_policy(); + assert!(policy.enable_endpoint_discovery); + assert_eq!(policy.failover_retry_count, 0); + assert_eq!(policy.session_token_retry_count, 0); + assert_eq!(policy.service_unavailable_retry_count, 0); + assert!(!policy.can_use_multiple_write_locations); + assert!(policy.location_endpoint.is_none()); + assert!(policy.retry_context.is_none()); + assert!(policy.operation_type.is_none()); + } + + #[tokio::test] + async fn test_retry_context_none_initially() { + let policy = create_test_policy(); + assert!(policy.retry_context.is_none()); + } + + #[tokio::test] + async fn test_should_retry_service_unavailable_with_preferred_locations() { + let mut policy = create_test_policy_with_preferred_locations(); + policy.operation_type = Some(OperationType::Read); + let error = create_error_with_status(StatusCode::ServiceUnavailable); + + let result = policy.should_retry_error(&error).await; + + match result { + RetryResult::Retry { after } => { + assert_eq!(after, Duration::ZERO); + assert_eq!(policy.service_unavailable_retry_count, 1); + assert!(policy.retry_context.is_some()); + } + _ => panic!("Expected retry for ServiceUnavailable with preferred locations"), + } + } + + #[tokio::test] + async fn test_should_retry_service_unavailable_without_preferred_locations() { + // Even with no preferred locations, applicable_endpoints returns the default endpoint + let mut policy = create_test_policy_no_locations(); + policy.operation_type = Some(OperationType::Read); + let error = create_error_with_status(StatusCode::ServiceUnavailable); + + let result = policy.should_retry_error(&error).await; + + match result { + RetryResult::Retry { after } => { + assert_eq!(after, Duration::ZERO); + assert_eq!(policy.service_unavailable_retry_count, 1); + } + _ => panic!("Expected retry for ServiceUnavailable (default endpoint available)"), + } + + // Second attempt should stop — only the default endpoint was available + let result = policy.should_retry_error(&error).await; + assert!( + !result.is_retry(), + "Expected DoNotRetry after exhausting the single default endpoint" + ); + } + + #[tokio::test] + async fn test_should_retry_internal_server_error_for_read() { + let mut policy = create_test_policy_with_preferred_locations(); + policy.operation_type = Some(OperationType::Read); + let error = create_error_with_status(StatusCode::InternalServerError); + + let result = policy.should_retry_error(&error).await; + + match result { + RetryResult::Retry { after } => { + assert_eq!(after, Duration::ZERO); + assert_eq!(policy.service_unavailable_retry_count, 1); + } + _ => panic!("Expected retry for InternalServerError on read request"), + } + } + + #[tokio::test] + async fn test_should_not_retry_internal_server_error_for_write() { + let mut policy = create_test_policy_with_preferred_locations(); + policy.operation_type = Some(OperationType::Create); + let error = create_error_with_status(StatusCode::InternalServerError); + + let result = policy.should_retry_error(&error).await; + + match result { + RetryResult::DoNotRetry => {} + _ => panic!("Expected DoNotRetry for InternalServerError on write request"), + } + } + + #[tokio::test] + async fn test_should_retry_gone_with_lease_not_found() { + let mut policy = create_test_policy_with_preferred_locations(); + policy.operation_type = Some(OperationType::Read); + let error = create_error_with_substatus( + StatusCode::Gone, + SubStatusCode::LEASE_NOT_FOUND.value() as u32, + ); + + let result = policy.should_retry_error(&error).await; + + match result { + RetryResult::Retry { after } => { + assert_eq!(after, Duration::ZERO); + assert_eq!(policy.service_unavailable_retry_count, 1); + } + _ => panic!("Expected retry for Gone with LeaseNotFound"), + } + } + + #[tokio::test] + async fn test_should_retry_gone_with_lease_not_found_for_write() { + let mut policy = create_test_policy_with_preferred_locations(); + policy.operation_type = Some(OperationType::Create); + policy.can_use_multiple_write_locations = true; + let error = create_error_with_substatus( + StatusCode::Gone, + SubStatusCode::LEASE_NOT_FOUND.value() as u32, + ); + + let result = policy.should_retry_error(&error).await; + + match result { + RetryResult::Retry { after } => { + assert_eq!(after, Duration::ZERO); + assert_eq!(policy.service_unavailable_retry_count, 1); + } + _ => panic!("Expected retry for Gone with LeaseNotFound on write"), + } + } + + #[tokio::test] + async fn test_should_retry_write_forbidden() { + let mut policy = create_test_policy(); + policy.operation_type = Some(OperationType::Create); + policy.location_endpoint = Some("https://test.documents.azure.com".parse().unwrap()); + let error = create_error_with_substatus( + StatusCode::Forbidden, + SubStatusCode::WRITE_FORBIDDEN.value() as u32, + ); + + let result = policy.should_retry_error(&error).await; + + match result { + RetryResult::Retry { after: _ } => { + assert_eq!(policy.failover_retry_count, 1); + } + _ => panic!("Expected retry for WriteForbidden"), + } + } + + #[tokio::test] + async fn test_should_retry_session_not_available_single_write() { + let mut policy = create_test_policy(); + policy.enable_endpoint_discovery = true; + policy.can_use_multiple_write_locations = false; + + let error = create_error_with_substatus( + StatusCode::NotFound, + SubStatusCode::READ_SESSION_NOT_AVAILABLE.value() as u32, + ); + + let result = policy.should_retry_error(&error).await; + + match result { + RetryResult::Retry { after } => { + assert_eq!(after, Duration::ZERO); + assert_eq!(policy.session_token_retry_count, 1); + assert!(policy.retry_context.is_some()); + } + _ => panic!("Expected retry for READ_SESSION_NOT_AVAILABLE"), + } + } + + #[tokio::test] + async fn test_should_not_retry_session_not_available_when_discovery_disabled() { + let mut policy = create_test_policy(); + policy.enable_endpoint_discovery = false; + + let error = create_error_with_substatus( + StatusCode::NotFound, + SubStatusCode::READ_SESSION_NOT_AVAILABLE.value() as u32, + ); + + let result = policy.should_retry_error(&error).await; + + match result { + RetryResult::DoNotRetry => { + assert_eq!(policy.session_token_retry_count, 1); + } + _ => panic!("Expected DoNotRetry when endpoint discovery disabled"), + } + } + + #[tokio::test] + async fn test_should_not_retry_session_not_available_after_all_endpoints_tried() { + let mut policy = create_test_policy(); + policy.enable_endpoint_discovery = true; + policy.can_use_multiple_write_locations = false; + policy.operation_type = Some(OperationType::Read); + // create_test_policy has 2 preferred locations, so set count to 2 + // to simulate all endpoints already tried + policy.session_token_retry_count = 2; + + let error = create_error_with_substatus( + StatusCode::NotFound, + SubStatusCode::READ_SESSION_NOT_AVAILABLE.value() as u32, + ); + + let result = policy.should_retry_error(&error).await; + match result { + RetryResult::DoNotRetry => { + assert_eq!(policy.session_token_retry_count, 3); + } + _ => panic!("Expected DoNotRetry after all endpoints tried"), + } + } + + #[tokio::test] + async fn test_should_not_retry_service_unavailable_after_all_endpoints_tried() { + let mut policy = create_test_policy_with_preferred_locations(); + policy.operation_type = Some(OperationType::Read); + // applicable_endpoints returns 1 (default endpoint) in test setup, + // so set count to 1 to simulate exhaustion + policy.service_unavailable_retry_count = 1; + + let error = create_error_with_status(StatusCode::ServiceUnavailable); + + let result = policy.should_retry_error(&error).await; + + match result { + RetryResult::DoNotRetry => { + assert_eq!(policy.service_unavailable_retry_count, 2); + } + _ => panic!("Expected DoNotRetry after all endpoints tried"), + } + } + + #[tokio::test] + async fn test_should_not_retry_service_unavailable_for_write_without_multi_write() { + let mut policy = create_test_policy_with_preferred_locations(); + policy.operation_type = Some(OperationType::Create); + policy.can_use_multiple_write_locations = false; + + let error = create_error_with_status(StatusCode::ServiceUnavailable); + + let result = policy.should_retry_error(&error).await; + + match result { + RetryResult::DoNotRetry => {} + _ => panic!("Expected DoNotRetry for write without multi-write support"), + } + } + + #[tokio::test] + async fn test_should_retry_too_many_requests() { + let mut policy = create_test_policy(); + let error = create_error_with_status(StatusCode::TooManyRequests); + + let result = policy.should_retry_error(&error).await; + + // TooManyRequests should be delegated to throttling policy + match result { + RetryResult::Retry { after: _ } => {} + _ => panic!("Expected retry for TooManyRequests (throttling)"), + } + } + + #[tokio::test] + async fn test_should_retry_response_service_unavailable() { + let mut policy = create_test_policy_with_preferred_locations(); + policy.operation_type = Some(OperationType::Read); + let response = create_raw_response(StatusCode::ServiceUnavailable); + + let result = policy.should_retry_response(&response).await; + + match result { + RetryResult::Retry { after } => { + assert_eq!(after, Duration::ZERO); + assert_eq!(policy.service_unavailable_retry_count, 1); + } + _ => panic!("Expected retry for ServiceUnavailable response"), + } + } + + #[tokio::test] + async fn test_should_retry_response_too_many_requests() { + let mut policy = create_test_policy(); + let response = create_raw_response(StatusCode::TooManyRequests); + + let result = policy.should_retry_response(&response).await; + + // Should be delegated to throttling policy + match result { + RetryResult::Retry { after: _ } => {} + _ => panic!("Expected retry for TooManyRequests response"), + } + } + + #[tokio::test] + async fn test_should_retry_for_error_response() { + let mut policy = create_test_policy_with_preferred_locations(); + policy.operation_type = Some(OperationType::Read); + let response = create_raw_response(StatusCode::ServiceUnavailable); + let result_with_response: azure_core::Result = Ok(response); + + let retry_result = policy.should_retry(&result_with_response).await; + + match retry_result { + RetryResult::Retry { after } => { + assert_eq!(after, Duration::ZERO); + } + _ => panic!("Expected retry for error response"), + } + } + + #[tokio::test] + async fn test_should_not_retry_for_success_response() { + let mut policy = create_test_policy(); + let response = create_raw_response(StatusCode::Ok); + let result_with_response: azure_core::Result = Ok(response); + + let retry_result = policy.should_retry(&result_with_response).await; + + match retry_result { + RetryResult::DoNotRetry => {} + _ => panic!("Expected DoNotRetry for success response"), + } + } + + #[tokio::test] + async fn test_should_retry_for_transport_error() { + let mut policy = create_test_policy_with_preferred_locations(); + policy.operation_type = Some(OperationType::Read); + let error = create_error_with_status(StatusCode::ServiceUnavailable); + let result_with_error: azure_core::Result = Err(error); + + let retry_result = policy.should_retry(&result_with_error).await; + + match retry_result { + RetryResult::Retry { after: _ } => {} + _ => panic!("Expected retry for transport error"), + } + } + + #[tokio::test] + async fn test_endpoint_failover_increments_count() { + let mut policy = create_test_policy(); + policy.location_endpoint = Some("https://test.documents.azure.com".parse().unwrap()); + + let result = policy + .should_retry_on_endpoint_failure(true, false, false, false, false) + .await; + + match result { + RetryResult::Retry { after: _ } => { + assert_eq!(policy.failover_retry_count, 1); + assert!(policy.retry_context.is_some()); + } + _ => panic!("Expected retry for endpoint failure"), + } + } + + #[tokio::test] + async fn test_endpoint_failover_respects_max_retry_count() { + let mut policy = create_test_policy(); + policy.failover_retry_count = MAX_RETRY_COUNT_ON_ENDPOINT_FAILURE + 1; + + let result = policy + .should_retry_on_endpoint_failure(true, false, false, false, false) + .await; + + match result { + RetryResult::DoNotRetry => {} + _ => panic!("Expected DoNotRetry after max failover retries"), + } + } + + #[tokio::test] + async fn test_endpoint_failover_respects_endpoint_discovery_disabled() { + let mut policy = create_test_policy(); + policy.enable_endpoint_discovery = false; + + let result = policy + .should_retry_on_endpoint_failure(true, false, false, false, false) + .await; + + match result { + RetryResult::DoNotRetry => {} + _ => panic!("Expected DoNotRetry when endpoint discovery disabled"), + } + } + + #[tokio::test] + async fn test_endpoint_failover_with_overwrite_discovery() { + let mut policy = create_test_policy(); + policy.enable_endpoint_discovery = false; + policy.location_endpoint = Some("https://test.documents.azure.com".parse().unwrap()); + + let result = policy + .should_retry_on_endpoint_failure(true, false, false, false, true) + .await; + + match result { + RetryResult::Retry { after: _ } => { + assert_eq!(policy.failover_retry_count, 1); + } + _ => panic!("Expected retry when overwrite_endpoint_discovery is true"), + } + } + + #[tokio::test] + async fn test_endpoint_failover_write_delay() { + let mut policy = create_test_policy(); + policy.location_endpoint = Some("https://test.documents.azure.com".parse().unwrap()); + policy.failover_retry_count = 1; + + let result = policy + .should_retry_on_endpoint_failure(false, false, false, false, false) + .await; + + match result { + RetryResult::Retry { after } => { + assert_eq!(after, Duration::milliseconds(RETRY_INTERVAL_MS)); + assert_eq!(policy.failover_retry_count, 2); + } + _ => panic!("Expected retry with delay for write request"), + } + } + + #[tokio::test] + async fn test_endpoint_failover_first_write_no_delay() { + let mut policy = create_test_policy(); + policy.location_endpoint = Some("https://test.documents.azure.com".parse().unwrap()); + + let result = policy + .should_retry_on_endpoint_failure(false, false, false, false, false) + .await; + + match result { + RetryResult::Retry { after } => { + assert_eq!(after, Duration::ZERO); + assert_eq!(policy.failover_retry_count, 1); + } + _ => panic!("Expected retry with zero delay for first write failover"), + } + } + + #[tokio::test] + async fn test_endpoint_failover_read_always_has_delay() { + let mut policy = create_test_policy(); + policy.location_endpoint = Some("https://test.documents.azure.com".parse().unwrap()); + + let result = policy + .should_retry_on_endpoint_failure(true, false, false, false, false) + .await; + + match result { + RetryResult::Retry { after } => { + assert_eq!(after, Duration::milliseconds(RETRY_INTERVAL_MS)); + } + _ => panic!("Expected retry with delay for read request"), + } + } + + #[tokio::test] + async fn test_before_send_request_sets_read_flag() { + let mut policy = create_test_policy(); + let mut request = create_test_request(); + + policy.before_send_request(&mut request).await; + + assert!(policy.operation_type.is_some()); + assert!(policy.operation_type.unwrap().is_read_only()); + } + + #[tokio::test] + async fn test_before_send_request_sets_write_flag() { + let mut policy = create_test_policy(); + let mut request = create_write_request(); + + policy.before_send_request(&mut request).await; + + assert!(policy.operation_type.is_some()); + assert!(!policy.operation_type.unwrap().is_read_only()); + } + + #[tokio::test] + async fn test_retry_context_applied_to_request() { + let mut policy = create_test_policy(); + policy.retry_context = Some(RetryContext { + retry_location_index: 1, + retry_request_on_preferred_locations: true, + route_to_hub: false, + }); + let mut request = create_test_request(); + + policy.before_send_request(&mut request).await; + + // The retry context should be applied to the request + assert!(policy.location_endpoint.is_some()); + } + + #[test] + fn test_retry_context_creation() { + let ctx = RetryContext { + retry_location_index: 2, + retry_request_on_preferred_locations: true, + route_to_hub: false, + }; + + assert_eq!(ctx.retry_location_index, 2); + assert!(ctx.retry_request_on_preferred_locations); + assert!(!ctx.route_to_hub); + } + + #[test] + fn test_constants_values() { + assert_eq!(RETRY_INTERVAL_MS, 1000); + assert_eq!(MAX_RETRY_COUNT_ON_ENDPOINT_FAILURE, 120); + } + + #[tokio::test] + async fn read_retries_on_unknown_server_error() { + let mut policy = create_test_policy_with_preferred_locations(); + policy.operation_type = Some(OperationType::Read); + + // A non-specific server error (e.g., 502 BadGateway) should be retried for reads + let error = create_error_with_status(StatusCode::BadGateway); + let result = policy.should_retry_error(&error).await; + + assert!( + result.is_retry(), + "Expected retry for BadGateway on read request" + ); + } + + #[tokio::test] + async fn read_does_not_retry_non_retryable_status_codes() { + for status in [ + StatusCode::BadRequest, + StatusCode::Unauthorized, + StatusCode::NotFound, + StatusCode::MethodNotAllowed, + StatusCode::Conflict, + StatusCode::PreconditionFailed, + StatusCode::PayloadTooLarge, + StatusCode::Locked, + constants::RETRY_WITH, + ] { + let mut policy = create_test_policy_with_preferred_locations(); + policy.operation_type = Some(OperationType::Read); + + let error = create_error_with_status(status); + let result = policy.should_retry_error(&error).await; + + assert!( + !result.is_retry(), + "Expected DoNotRetry for {status:?} on read request" + ); + } + } + + #[tokio::test] + async fn write_does_not_retry_unknown_server_error() { + let mut policy = create_test_policy_with_preferred_locations(); + policy.operation_type = Some(OperationType::Create); + + // A non-specific server error should NOT be retried for writes + let error = create_error_with_status(StatusCode::BadGateway); + let result = policy.should_retry_error(&error).await; + + assert!( + !result.is_retry(), + "Expected DoNotRetry for BadGateway on write request" + ); + } + + #[tokio::test] + async fn read_retries_on_forbidden_without_write_forbidden_substatus() { + let mut policy = create_test_policy_with_preferred_locations(); + policy.operation_type = Some(OperationType::Read); + + // Forbidden without WRITE_FORBIDDEN substatus should be retried for reads + let error = create_error_with_status(StatusCode::Forbidden); + let result = policy.should_retry_error(&error).await; + + assert!( + result.is_retry(), + "Expected retry for Forbidden (no substatus) on read request" + ); + } + + fn create_connection_error(message: &str) -> azure_core::Error { + azure_core::Error::with_message( + azure_core::error::ErrorKind::Connection, + message.to_string(), + ) + } + + fn create_timeout_error(message: &str) -> azure_core::Error { + azure_core::Error::with_message(azure_core::error::ErrorKind::Io, message.to_string()) + } + + fn create_io_error(message: &str) -> azure_core::Error { + azure_core::Error::with_message(azure_core::error::ErrorKind::Io, message.to_string()) + } + + #[tokio::test] + async fn connection_error_retries_read() { + let mut policy = create_test_policy(); + let mut request = create_test_request(); + policy.before_send_request(&mut request).await; + + let err = create_connection_error("connection refused"); + let result = policy.should_retry(&Err(err)).await; + assert!( + result.is_retry(), + "connection error should retry read requests" + ); + } + + #[tokio::test] + async fn connection_error_retries_write() { + let mut policy = create_test_policy(); + let mut request = create_write_request(); + policy.before_send_request(&mut request).await; + + let err = create_connection_error("connection refused"); + let result = policy.should_retry(&Err(err)).await; + assert!( + result.is_retry(), + "connection error should retry write requests" + ); + } + + #[tokio::test] + async fn connection_error_retries_on_same_endpoint() { + let mut policy = create_test_policy(); + let mut request = create_test_request(); + policy.before_send_request(&mut request).await; + + // First 3 connection errors should retry on the same endpoint. + for i in 1..=3 { + let err = create_connection_error("connection refused"); + let result = policy.should_retry(&Err(err)).await; + assert!(result.is_retry(), "connection attempt {i} should retry"); + assert_eq!(policy.connection_retry_count, i); + assert_eq!( + policy.failover_retry_count, 0, + "should not failover during local retries" + ); + } + } + + #[tokio::test] + async fn connection_error_fails_over_after_max_retries() { + let mut policy = create_test_policy(); + let mut request = create_test_request(); + policy.before_send_request(&mut request).await; + + // Exhaust local retries. + for _ in 0..3 { + let err = create_connection_error("connection refused"); + policy.should_retry(&Err(err)).await; + } + + // Next connection error should trigger failover. + let err = create_connection_error("connection refused"); + let result = policy.should_retry(&Err(err)).await; + assert!(result.is_retry(), "should failover to next endpoint"); + assert_eq!( + policy.failover_retry_count, 1, + "failover_retry_count should increment after local retries exhausted" + ); + assert_eq!( + policy.connection_retry_count, 0, + "connection_retry_count should reset for new endpoint" + ); + } + + #[tokio::test] + async fn response_timeout_retries_read() { + let mut policy = create_test_policy(); + let mut request = create_test_request(); + policy.before_send_request(&mut request).await; + + let err = create_timeout_error("response timeout"); + let result = policy.should_retry(&Err(err)).await; + assert!( + result.is_retry(), + "response timeout should retry read requests" + ); + } + + #[tokio::test] + async fn response_timeout_does_not_retry_write() { + let mut policy = create_test_policy(); + let mut request = create_write_request(); + policy.before_send_request(&mut request).await; + + let err = create_timeout_error("response timeout"); + let result = policy.should_retry(&Err(err)).await; + assert_eq!( + result, + RetryResult::DoNotRetry, + "response timeout should NOT retry write requests" + ); + } + + #[tokio::test] + async fn response_timeout_read_uses_service_unavailable_counter() { + let mut policy = create_test_policy(); + let mut request = create_test_request(); + policy.before_send_request(&mut request).await; + + let err = create_timeout_error("response timeout"); + let result = policy.should_retry(&Err(err)).await; + assert!(result.is_retry()); + assert_eq!( + policy.service_unavailable_retry_count, 1, + "service_unavailable_retry_count should increment on response timeout for reads" + ); + } + + #[tokio::test] + async fn unknown_io_error_retries_read() { + let mut policy = create_test_policy(); + let mut request = create_test_request(); + policy.before_send_request(&mut request).await; + + let err = create_io_error("some unrelated IO error"); + let result = policy.should_retry(&Err(err)).await; + assert!( + result.is_retry(), + "unknown IO errors should retry read requests" + ); + } + + #[tokio::test] + async fn unknown_io_error_does_not_retry_write() { + let mut policy = create_test_policy(); + let mut request = create_write_request(); + policy.before_send_request(&mut request).await; + + let err = create_io_error("some unrelated IO error"); + let result = policy.should_retry(&Err(err)).await; + assert_eq!( + result, + RetryResult::DoNotRetry, + "unknown IO errors should not retry write requests" + ); + } + + /// **Core bug-fix test.** On a multi-master account with PPCB enabled, a write + /// receiving 403/3 must reach `is_request_eligible_for_partition_level_circuit_breaker`, + /// which internally calls `increment_request_failure_counter_and_check_if_partition_can_failover`. + /// + /// Before the fix, the outer `partition_level_failover_enabled()` guard short-circuited + /// on multi-master accounts, so the counter was never incremented and partition-level + /// failover could never trigger. After the fix, each 403/3 increments the counter. + #[tokio::test] + async fn write_forbidden_on_multi_master_increments_ppcb_counter() { + let mut policy = create_multi_master_ppcb_policy(); + let request = create_ppcb_write_request(); + + policy.operation_type = Some(OperationType::Create); + policy.can_use_multiple_write_locations = true; + policy.location_endpoint = Some("https://test-westus.documents.azure.com".parse().unwrap()); + policy.request = Some(request); + + // Call the method that the 403/3 handler now invokes. + // Default write threshold = 5, so each of the first 5 calls should return false + // (counter not yet exceeded), but still increment the counter. + for i in 1..=5 { + let result = policy.increment_failure_counter_and_check_circuit_breaker_eligibility(); + assert!( + !result, + "Attempt {i}: Expected false (counter {i} <= threshold 5)" + ); + } + + // The 6th call should exceed the threshold (6 > 5) and return true + let result = policy.increment_failure_counter_and_check_circuit_breaker_eligibility(); + assert!( + result, + "Expected true after exceeding write failure threshold (6 > 5)" + ); + } + + /// On a multi-master account with PPCB enabled, once the write failure count + /// exceeds the threshold (default: 5), the next 403/3 should trigger partition-level + /// failover: the partition is marked unavailable and the retry is immediate + /// (Duration::ZERO) without going through account-level failover. + #[tokio::test] + async fn write_forbidden_on_multi_master_triggers_partition_failover_after_threshold() { + let mut policy = create_multi_master_ppcb_policy(); + let request = create_ppcb_write_request(); + + policy.operation_type = Some(OperationType::Create); + policy.can_use_multiple_write_locations = true; + policy.location_endpoint = Some("https://test-westus.documents.azure.com".parse().unwrap()); + policy.request = Some(request); + + // Pre-pump the PPCB counter to just below threshold (5 increments). + // Each call to increment_failure_counter_and_check_circuit_breaker_eligibility() + // increments the counter via increment_request_failure_counter_and_check_if_partition_can_failover. + for _ in 1..=5 { + let _ = policy.increment_failure_counter_and_check_circuit_breaker_eligibility(); + } + + let failover_count_before = policy.failover_retry_count; + + // Now call should_retry_error with a 403/3 error. + // The 6th increment will exceed the threshold → the PPCB path returns + // Retry { after: ZERO } immediately, without calling should_retry_on_endpoint_failure + // (and thus without any HTTP calls). + let error = create_error_with_substatus( + StatusCode::Forbidden, + SubStatusCode::WRITE_FORBIDDEN.value() as u32, + ); + + let result = policy.should_retry_error(&error).await; + + assert!( + result.is_retry(), + "Expected retry after PPCB threshold exceeded" + ); + + // The partition-level path returns Retry { after: ZERO } directly, + // WITHOUT going through should_retry_on_endpoint_failure, so + // failover_retry_count should NOT have increased. + assert_eq!( + policy.failover_retry_count, failover_count_before, + "failover_retry_count should NOT increment when partition-level failover succeeds \ + (partition path bypasses account-level retry)" + ); + } + + /// Without PPCB enabled, `is_request_eligible_for_partition_level_circuit_breaker` + /// should always return false, meaning partition-level failover never triggers + /// for 403/3 on multi-master accounts. + #[tokio::test] + async fn write_forbidden_on_multi_master_without_ppcb_returns_ineligible() { + let manager = create_multi_master_endpoint_manager(); + // PPAF = false, PPCB = false + let partition_manager = GlobalPartitionEndpointManager::new(manager.clone(), false, false); + let mut policy = ClientRetryPolicy::new(manager, partition_manager, None); + + let request = create_ppcb_write_request(); + policy.operation_type = Some(OperationType::Create); + policy.can_use_multiple_write_locations = true; + policy.location_endpoint = Some("https://test-westus.documents.azure.com".parse().unwrap()); + policy.request = Some(request); + + // Even after many calls, should always return false (PPCB disabled) + for i in 1..=8 { + let result = policy.increment_failure_counter_and_check_circuit_breaker_eligibility(); + assert!(!result, "Attempt {i}: Expected false when PPCB is disabled"); + } + } +} diff --git a/sdk/cosmos/azure_data_cosmos/src/retry_policies/metadata_request_retry_policy.rs b/sdk/cosmos/azure_data_cosmos/src/retry_policies/metadata_request_retry_policy.rs new file mode 100644 index 00000000000..a663a9e34fd --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos/src/retry_policies/metadata_request_retry_policy.rs @@ -0,0 +1,687 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +use super::{ + get_substatus_code_from_error, get_substatus_code_from_response, is_non_retryable_status_code, + RetryResult, +}; +use crate::constants::SubStatusCode; +use crate::cosmos_request::CosmosRequest; +use crate::operation_context::OperationType; +use crate::regions::Region; +use crate::retry_policies::resource_throttle_retry_policy::ResourceThrottleRetryPolicy; +use crate::routing::global_endpoint_manager::GlobalEndpointManager; +use azure_core::http::{RawResponse, StatusCode}; +use azure_core::time::Duration; +use std::sync::Arc; +use tracing::trace; + +/// Retry policy for handling metadata request failures. +#[derive(Debug)] +pub(crate) struct MetadataRequestRetryPolicy { + /// An instance of GlobalEndpointManager. + global_endpoint_manager: Arc, + + /// Defines the throttling retry policy that is used as the underlying retry policy. + throttling_retry_policy: ResourceThrottleRetryPolicy, + + /// An instance containing the location endpoint where the partition key + /// range http request will be sent over. + retry_context: Option, + + /// An integer capturing the current retry count on unavailable endpoint. + unavailable_endpoint_retry_count: usize, + + /// Regions excluded from routing for the current request. + excluded_regions: Option>, +} + +/// A helper struct containing the required attributes for metadata retry context. +#[derive(Clone, Debug)] +struct MetadataRetryContext { + /// An integer defining the current retry location index. + retry_location_index: usize, + + /// A boolean flag indicating if the request should retry on preferred locations. + retry_request_on_preferred_locations: bool, +} + +impl MetadataRequestRetryPolicy { + /// Creates a new MetadataRequestRetryPolicy with the specified global endpoint manager. + /// + /// # Summary + /// Initializes a metadata request retry policy that handles transient failures for metadata operations + /// by retrying requests across multiple endpoints when service unavailable or internal server errors occur. + /// The policy integrates with the global endpoint manager to route requests to alternative endpoints. + /// + /// # Arguments + /// * `global_endpoint_manager` - The global endpoint manager for routing requests across regions + /// + /// # Returns + /// A new instance of `MetadataRequestRetryPolicy` configured with: + /// - Maximum unavailable endpoint retries based on preferred location count + /// - Underlying throttling retry policy for 429 responses + /// - Initial retry count set to zero + pub fn new(global_endpoint_manager: Arc) -> Self { + Self { + global_endpoint_manager, + throttling_retry_policy: ResourceThrottleRetryPolicy::new(5, 200, 10), + retry_context: None, + unavailable_endpoint_retry_count: 0, + excluded_regions: None, + } + } + + /// Method that is called before a request is sent to allow the retry policy implementation + /// to modify the state of the request. + /// + /// # Arguments + /// + /// * `request` - The request being sent to the service + pub(crate) async fn before_send_request(&mut self, request: &mut CosmosRequest) { + let _stat = self.global_endpoint_manager.refresh_location(false).await; + + self.excluded_regions = request.excluded_regions.clone().map(|e| e.0); + + // Clear the previous location-based routing directive + request.request_context.clear_route_to_location(); + + if let Some(ref ctx) = self.retry_context { + let mut req_ctx = request.request_context.clone(); + req_ctx.route_to_location_index( + ctx.retry_location_index, + ctx.retry_request_on_preferred_locations, + ); + request.request_context = req_ctx; + } + + let metadata_location_endpoint = self + .global_endpoint_manager + .resolve_service_endpoint(request); + + trace!( + "MetadataRequestThrottleRetryPolicy: Routing the metadata request to: {:?} for operation type: {:?} and resource type: {:?}.", + metadata_location_endpoint, + request.operation_type, + request.resource_type + ); + + request + .request_context + .route_to_location_endpoint(metadata_location_endpoint); + } + + /// Determines whether an HTTP request should be retried based on the response or error + /// + /// This method evaluates the result of an HTTP request attempt and decides whether + /// the operation should be retried, and if so, how long to wait before the next attempt. + /// + /// # Arguments + /// + /// * `response` - A reference to the result of the HTTP request attempt. This can be: + /// - `Ok(RawResponse)` - A successful HTTP response (which may still indicate an error via status code) + /// - `Err(azure_core::Error)` - A network or client-side error + /// + /// # Returns + /// + /// A `RetryResult` indicating the retry decision. + pub(crate) async fn should_retry( + &mut self, + response: &azure_core::Result, + ) -> RetryResult { + match response { + Ok(resp) if resp.status().is_server_error() || resp.status().is_client_error() => { + self.should_retry_response(resp).await + } + Ok(_) => RetryResult::DoNotRetry, + Err(err) => self.should_retry_error(err).await, + } + } + + /// Determines whether to retry a metadata operation that failed with an error. + /// + /// # Summary + /// Evaluates the error to determine if it represents a transient failure (service unavailable, + /// internal server error, lease not found, or database account not found) that can be retried + /// on an alternative endpoint. Falls back to throttling retry logic for 429 responses. + /// + /// # Arguments + /// * `err` - The error that occurred during the metadata operation + /// + /// # Returns + /// A `RetryResult` indicating whether to retry and the delay duration: + /// - `Retry { after: Duration::ZERO }` for retryable metadata errors + /// - Delegates to throttling policy for other errors + pub async fn should_retry_error(&mut self, err: &azure_core::Error) -> RetryResult { + let status_code = err.http_status().unwrap_or(StatusCode::UnknownValue(0)); + let sub_status_code = get_substatus_code_from_error(err); + + let retry_result = self.should_retry_with_status_code(status_code, sub_status_code); + if retry_result.is_retry() { + return retry_result; + } + + self.throttling_retry_policy.should_retry_error(err) + } + + /// Determines whether to retry a metadata operation based on the HTTP response. + /// + /// # Summary + /// Examines the HTTP response status code and sub-status to determine if the failure is transient + /// (503 service unavailable, 500 internal server error, 410 lease not found, 403 database account + /// not found) and can be retried on an alternative endpoint. Delegates to throttling policy for + /// rate limiting (429) responses. + /// + /// # Arguments + /// * `response` - The HTTP response received from the metadata operation + /// + /// # Returns + /// A `RetryResult` indicating whether to retry and the delay duration: + /// - `Retry { after: Duration::ZERO }` for retryable metadata failures + /// - Delegates to throttling policy for rate limiting errors + pub async fn should_retry_response(&mut self, response: &RawResponse) -> RetryResult { + let status_code = response.status(); + let sub_status_code = get_substatus_code_from_response(&response.clone()); + + let retry_result = self.should_retry_with_status_code(status_code, sub_status_code); + if retry_result.is_retry() { + return retry_result; + } + + self.throttling_retry_policy.should_retry_response(response) + } + + /// Core retry decision logic based on status code and sub-status code. + /// + /// # Summary + /// Determines if a metadata request should be retried based on the HTTP status code. + /// Any status code not in the non-retryable whitelist (400, 401, 404, 409, 412, 413) + /// is retried on an alternative endpoint. If retry is allowed, increments the location + /// index to route the next attempt to a different endpoint. + /// + /// # Arguments + /// * `status_code` - The HTTP status code from the response + /// * `sub_status_code` - The Cosmos DB specific sub-status code (reserved for future use) + /// + /// # Returns + /// A `RetryResult`: + /// - `Retry { after: Duration::ZERO }` if the error is retryable and retry count not exceeded + /// - `DoNotRetry` for non-retryable errors or if max retries exceeded + fn should_retry_with_status_code( + &mut self, + status_code: StatusCode, + sub_status_code: Option, + ) -> RetryResult { + if !is_non_retryable_status_code(status_code, sub_status_code) + && self.increment_retry_index_on_unavailable_endpoint_for_metadata_read() + { + return RetryResult::Retry { + after: Duration::ZERO, + }; + } + + RetryResult::DoNotRetry + } + + /// Increments the location index when an unavailable endpoint exception occurs, for any future read requests. + /// + /// # Summary + /// Uses the applicable endpoints from the global endpoint manager to determine the maximum + /// number of retry attempts. Each retry routes the request to the next available endpoint. + /// + /// # Returns + /// + /// A boolean flag indicating if the operation was successful. + fn increment_retry_index_on_unavailable_endpoint_for_metadata_read(&mut self) -> bool { + self.unavailable_endpoint_retry_count += 1; + + let endpoints = self + .global_endpoint_manager + .applicable_endpoints(OperationType::Read, self.excluded_regions.as_ref()); + + if self.unavailable_endpoint_retry_count > endpoints.len() { + trace!( + "MetadataRequestThrottleRetryPolicy: Retry count: {} has exceeded the number of applicable endpoints: {}.", + self.unavailable_endpoint_retry_count, + endpoints.len() + ); + return false; + } + + trace!( + "MetadataRequestThrottleRetryPolicy: Incrementing the metadata retry location index to: {}.", + self.unavailable_endpoint_retry_count + ); + + self.retry_context = Some(MetadataRetryContext { + retry_location_index: self.unavailable_endpoint_retry_count, + retry_request_on_preferred_locations: true, + }); + + true + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::operation_context::OperationType; + use crate::options::ExcludedRegions; + use crate::partition_key::PartitionKey; + use crate::regions::Region; + use crate::resource_context::{ResourceLink, ResourceType}; + use crate::routing::global_endpoint_manager::GlobalEndpointManager; + use azure_core::http::headers::Headers; + use azure_core::http::ClientOptions; + use azure_core::Bytes; + use std::sync::Arc; + + fn create_test_endpoint_manager() -> Arc { + let pipeline = azure_core::http::Pipeline::new( + option_env!("CARGO_PKG_NAME"), + option_env!("CARGO_PKG_VERSION"), + ClientOptions::default(), + Vec::new(), + Vec::new(), + None, + ); + + GlobalEndpointManager::new( + "https://test.documents.azure.com".parse().unwrap(), + vec![Region::from("West US"), Region::from("East US")], + vec![], + pipeline, + ) + } + + fn create_test_endpoint_manager_no_locations() -> Arc { + let pipeline = azure_core::http::Pipeline::new( + option_env!("CARGO_PKG_NAME"), + option_env!("CARGO_PKG_VERSION"), + ClientOptions::default(), + Vec::new(), + Vec::new(), + None, + ); + + GlobalEndpointManager::new( + "https://test.documents.azure.com".parse().unwrap(), + vec![], + vec![], + pipeline, + ) + } + + fn create_test_endpoint_manager_with_preferred_locations() -> Arc { + let pipeline = azure_core::http::Pipeline::new( + option_env!("CARGO_PKG_NAME"), + option_env!("CARGO_PKG_VERSION"), + ClientOptions::default(), + Vec::new(), + Vec::new(), + None, + ); + + GlobalEndpointManager::new( + "https://test.documents.azure.com".parse().unwrap(), + vec![Region::EAST_ASIA, Region::WEST_US, Region::NORTH_CENTRAL_US], + vec![], + pipeline, + ) + } + + fn create_test_policy() -> MetadataRequestRetryPolicy { + let manager = create_test_endpoint_manager(); + MetadataRequestRetryPolicy::new(manager) + } + + fn create_test_policy_no_locations() -> MetadataRequestRetryPolicy { + let manager = create_test_endpoint_manager_no_locations(); + MetadataRequestRetryPolicy::new(manager) + } + + fn create_test_policy_with_preferred_locations() -> MetadataRequestRetryPolicy { + let manager = create_test_endpoint_manager_with_preferred_locations(); + MetadataRequestRetryPolicy::new(manager) + } + + fn create_test_request() -> CosmosRequest { + let resource_link = ResourceLink::root(ResourceType::Documents); + let mut request = CosmosRequest::builder(OperationType::Read, resource_link.clone()) + .partition_key(PartitionKey::from("test")) + .build() + .unwrap(); + + request.request_context.location_endpoint_to_route = + Some("https://test.documents.azure.com".parse().unwrap()); + request + } + + fn create_raw_response(status_code: StatusCode) -> RawResponse { + let headers = Headers::new(); + RawResponse::from_bytes(status_code, headers, Bytes::new()) + } + + fn create_error_with_status(status: StatusCode) -> azure_core::Error { + let response = create_raw_response(status); + azure_core::Error::new( + azure_core::error::ErrorKind::HttpResponse { + status: response.status(), + error_code: None, + raw_response: Some(Box::new(response)), + }, + "Test error", + ) + } + + #[tokio::test] + async fn test_new_policy_initialization() { + let policy = create_test_policy_with_preferred_locations(); + assert_eq!(policy.unavailable_endpoint_retry_count, 0); + assert!(policy.excluded_regions.is_none()); + } + + #[tokio::test] + async fn test_retry_context_none_initially() { + let policy = create_test_policy(); + assert!(policy.retry_context.is_none()); + } + + #[tokio::test] + async fn test_should_retry_service_unavailable_error() { + let mut policy = create_test_policy_no_locations(); + let error = create_error_with_status(StatusCode::ServiceUnavailable); + + let result = policy.should_retry_error(&error).await; + assert!(result.is_retry()); + if let RetryResult::Retry { after } = result { + assert_eq!(after, Duration::ZERO); + } + } + + #[tokio::test] + async fn test_should_retry_internal_server_error() { + let mut policy = create_test_policy_with_preferred_locations(); + let error = create_error_with_status(StatusCode::InternalServerError); + + let result = policy.should_retry_error(&error).await; + assert!(result.is_retry()); + } + + #[tokio::test] + async fn test_should_retry_service_unavailable_response() { + let mut policy = create_test_policy_with_preferred_locations(); + let response = create_raw_response(StatusCode::ServiceUnavailable); + + let result = policy.should_retry_response(&response).await; + assert!(result.is_retry()); + } + + #[tokio::test] + async fn test_should_retry_internal_server_error_response() { + let mut policy = create_test_policy_with_preferred_locations(); + let response = create_raw_response(StatusCode::InternalServerError); + + let result = policy.should_retry_response(&response).await; + assert!(result.is_retry()); + } + + #[tokio::test] + async fn test_should_not_retry_ok_response() { + let mut policy = create_test_policy(); + let response = create_raw_response(StatusCode::Ok); + + let result = policy.should_retry(&Ok(response)).await; + assert!(!result.is_retry()); + } + + #[tokio::test] + async fn test_should_not_retry_created_response() { + let mut policy = create_test_policy(); + let response = create_raw_response(StatusCode::Created); + + let result = policy.should_retry(&Ok(response)).await; + assert!(!result.is_retry()); + } + + #[tokio::test] + async fn test_increment_retry_index_on_unavailable_endpoint() { + let mut policy = create_test_policy_with_preferred_locations(); + let initial_count = policy.unavailable_endpoint_retry_count; + + let result = policy.increment_retry_index_on_unavailable_endpoint_for_metadata_read(); + assert!(result); + assert_eq!(policy.unavailable_endpoint_retry_count, initial_count + 1); + assert!(policy.retry_context.is_some()); + } + + #[tokio::test] + async fn test_increment_retry_exceeds_max_count() { + let mut policy = create_test_policy_no_locations(); + + // With no preferred locations, applicable_endpoints returns 1 (default endpoint). + // Exhaust that single retry attempt. + assert!(policy.increment_retry_index_on_unavailable_endpoint_for_metadata_read()); + + // Second attempt should fail — only the default endpoint was available + let result = policy.increment_retry_index_on_unavailable_endpoint_for_metadata_read(); + assert!(!result); + } + + #[tokio::test] + async fn test_retry_context_set_after_increment() { + let mut policy = create_test_policy_no_locations(); + + policy.increment_retry_index_on_unavailable_endpoint_for_metadata_read(); + + assert!(policy.retry_context.is_some()); + if let Some(ctx) = &policy.retry_context { + assert!(ctx.retry_request_on_preferred_locations); + assert_eq!( + ctx.retry_location_index, + policy.unavailable_endpoint_retry_count + ); + } + } + + #[tokio::test] + async fn test_should_retry_with_ok_result() { + let mut policy = create_test_policy(); + let response = create_raw_response(StatusCode::Ok); + + let result = policy.should_retry(&Ok(response)).await; + assert!(!result.is_retry()); + } + + #[tokio::test] + async fn test_should_retry_with_server_error_result() { + let mut policy = create_test_policy_no_locations(); + let response = create_raw_response(StatusCode::InternalServerError); + + let result = policy.should_retry(&Ok(response)).await; + assert!(result.is_retry()); + } + + #[tokio::test] + async fn test_should_retry_with_error_result() { + let mut policy = create_test_policy_no_locations(); + let error = create_error_with_status(StatusCode::ServiceUnavailable); + + let result = policy.should_retry(&Err(error)).await; + assert!(result.is_retry()); + } + + #[tokio::test] + async fn test_should_not_retry_bad_request() { + let mut policy = create_test_policy(); + let response = create_raw_response(StatusCode::BadRequest); + + let result = policy.should_retry_response(&response).await; + assert!(!result.is_retry()); + } + + #[tokio::test] + async fn test_should_not_retry_not_found() { + let mut policy = create_test_policy(); + let response = create_raw_response(StatusCode::NotFound); + + let result = policy.should_retry_response(&response).await; + assert!(!result.is_retry()); + } + + #[tokio::test] + async fn test_should_not_retry_unauthorized() { + let mut policy = create_test_policy(); + let response = create_raw_response(StatusCode::Unauthorized); + + let result = policy.should_retry_response(&response).await; + assert!(!result.is_retry()); + } + + #[tokio::test] + async fn test_should_not_retry_conflict() { + let mut policy = create_test_policy(); + let response = create_raw_response(StatusCode::Conflict); + + let result = policy.should_retry_response(&response).await; + assert!(!result.is_retry()); + } + + #[tokio::test] + async fn test_should_not_retry_precondition_failed() { + let mut policy = create_test_policy(); + let response = create_raw_response(StatusCode::PreconditionFailed); + + let result = policy.should_retry_response(&response).await; + assert!(!result.is_retry()); + } + + #[tokio::test] + async fn test_should_retry_forbidden_on_another_endpoint() { + let mut policy = create_test_policy_no_locations(); + let response = create_raw_response(StatusCode::Forbidden); + + let result = policy.should_retry_response(&response).await; + assert!(result.is_retry()); + } + + #[tokio::test] + async fn test_should_retry_gone_on_another_endpoint() { + let mut policy = create_test_policy_no_locations(); + let response = create_raw_response(StatusCode::Gone); + + let result = policy.should_retry_response(&response).await; + assert!(result.is_retry()); + } + + #[tokio::test] + async fn test_multiple_retries_increment_counter() { + let mut policy = create_test_policy_no_locations(); + // Reset the counter to 0 to allow multiple increments + policy.unavailable_endpoint_retry_count = 0; + let initial_count = policy.unavailable_endpoint_retry_count; + + let error1 = create_error_with_status(StatusCode::ServiceUnavailable); + let _result1 = policy.should_retry_error(&error1).await; + assert_eq!(policy.unavailable_endpoint_retry_count, initial_count + 1); + + // Can't test second retry as it exceeds max_unavailable_endpoint_retry_count (which is 1) + // So just verify the first increment worked + } + + #[tokio::test] + async fn test_before_send_request_clears_routing() { + let mut policy = create_test_policy(); + let mut request = create_test_request(); + + // Set some routing info + request.request_context.location_index_to_route = Some(5); + + policy.before_send_request(&mut request).await; + + // After before_send_request, routing should be updated + assert!(request.request_context.location_endpoint_to_route.is_some()); + } + + #[tokio::test] + async fn test_retry_context_affects_routing() { + let mut policy = create_test_policy(); + let mut request = create_test_request(); + + // Set up retry context + policy.retry_context = Some(MetadataRetryContext { + retry_location_index: 1, + retry_request_on_preferred_locations: true, + }); + + policy.before_send_request(&mut request).await; + + // Verify the request was updated with retry context + assert!(request.request_context.location_endpoint_to_route.is_some()); + } + + #[tokio::test] + async fn test_policy_debug_format() { + let policy = create_test_policy(); + let debug_str = format!("{:?}", policy); + assert!(debug_str.contains("MetadataRequestRetryPolicy")); + } + + #[test] + fn test_retry_context_clone() { + let ctx = MetadataRetryContext { + retry_location_index: 3, + retry_request_on_preferred_locations: false, + }; + + let cloned = ctx.clone(); + assert_eq!(ctx.retry_location_index, cloned.retry_location_index); + assert_eq!( + ctx.retry_request_on_preferred_locations, + cloned.retry_request_on_preferred_locations + ); + } + + #[tokio::test] + async fn test_before_send_request_captures_excluded_regions() { + let mut policy = create_test_policy_with_preferred_locations(); + let resource_link = ResourceLink::root(ResourceType::Databases); + let mut request = CosmosRequest::builder(OperationType::Read, resource_link) + .partition_key(PartitionKey::from("test")) + .excluded_regions(Some(ExcludedRegions::from_iter([Region::EAST_ASIA]))) + .build() + .unwrap(); + request.request_context.location_endpoint_to_route = + Some("https://test.documents.azure.com".parse().unwrap()); + + policy.before_send_request(&mut request).await; + + assert!(policy.excluded_regions.is_some()); + assert_eq!(policy.excluded_regions.as_ref().unwrap().len(), 1); + assert_eq!( + policy.excluded_regions.as_ref().unwrap()[0], + Region::EAST_ASIA + ); + } + + #[tokio::test] + async fn test_excluded_regions_reduce_retry_attempts() { + let mut policy = create_test_policy_with_preferred_locations(); + // 3 preferred locations: EAST_ASIA, WEST_US, NORTH_CENTRAL_US + // Exclude 2 of them so only 1 endpoint remains + policy.excluded_regions = Some(vec![Region::EAST_ASIA, Region::WEST_US]); + + let error = create_error_with_status(StatusCode::ServiceUnavailable); + + // First retry should succeed — one endpoint is still available + let result = policy.should_retry_error(&error).await; + assert!(result.is_retry()); + + // Second retry should fail — only one non-excluded endpoint was available + let result = policy.should_retry_error(&error).await; + assert!( + !result.is_retry(), + "Expected DoNotRetry after exhausting non-excluded endpoints" + ); + } +} diff --git a/sdk/cosmos/azure_data_cosmos/src/routing/global_endpoint_manager.rs b/sdk/cosmos/azure_data_cosmos/src/routing/global_endpoint_manager.rs new file mode 100644 index 00000000000..99a4131e79a --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos/src/routing/global_endpoint_manager.rs @@ -0,0 +1,785 @@ +//! Concrete (yet unimplemented) GlobalEndpointManager. + +use crate::background_task_manager::BackgroundTaskManager; +use crate::constants::ACCOUNT_PROPERTIES_KEY; +use crate::cosmos_request::CosmosRequest; +use crate::models::AccountProperties; +use crate::operation_context::OperationType; +use crate::regions::Region; +use crate::resource_context::{ResourceLink, ResourceType}; +use crate::routing::async_cache::AsyncCache; +use crate::routing::location_cache::{LocationCache, RequestOperation}; +use azure_core::http::{Context, Pipeline, Response}; +use azure_core::time::Duration; +use azure_core::Error; +use std::collections::HashMap; +use std::fmt::Debug; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Mutex, OnceLock, Weak}; +use tracing::info; +use url::Url; + +/// Type alias for the account refresh callback function. +pub type OnAccountRefreshCallback = Arc; + +/// Default interval (in seconds) at which the background account refresh loop runs. +const BACKGROUND_ACCOUNT_REFRESH_INTERVAL_SECS: i64 = 600; + +/// Manages global endpoint routing, failover, and location awareness for Cosmos DB requests. +/// +/// This component coordinates multi-region request routing by maintaining location cache state, +/// refreshing account properties, and resolving service endpoints based on request characteristics +/// and availability. It handles endpoint discovery, tracks unavailable endpoints, and supports +/// multi-master write configurations. +pub(crate) struct GlobalEndpointManager { + /// The primary default endpoint URL for the Cosmos DB account + default_endpoint: Url, + + /// Thread-safe cache of location information including read/write endpoints and availability status + location_cache: Mutex, + + /// HTTP pipeline for making requests to the Cosmos DB service + pipeline: Pipeline, + + /// Cache for account properties with 600 second TTL to reduce redundant service calls + account_properties_cache: AsyncCache<&'static str, AccountProperties>, + + /// Optional callback invoked when account properties are refreshed via HTTP call. + /// Uses `OnceLock` because the callback is registered exactly once during client + /// construction and never changes afterward. This avoids the need for a `Mutex` + /// (which is error-prone in async code) and is completely lock-free on reads. + on_account_refresh: OnceLock, + + /// Flag indicating if the background connection initialization task is active. + background_account_refresh_active: AtomicBool, + + /// Manages background tasks and signals them to stop when dropped. + background_task_manager: BackgroundTaskManager, + + /// Background account refresh interval in seconds. Default is 10 minutes. + background_account_refresh_interval: Duration, +} + +impl Debug for GlobalEndpointManager { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("GlobalEndpointManager") + .field("default_endpoint", &self.default_endpoint) + .field("location_cache", &self.location_cache) + .field("pipeline", &self.pipeline) + .field("account_properties_cache", &self.account_properties_cache) + .field("on_account_refresh", &"") + .finish() + } +} + +impl GlobalEndpointManager { + /// Creates a new `GlobalEndpointManager` with a `LocationCache` initialized + /// from the provided `default_endpoint` and `preferred_locations`. + /// + /// # Summary + /// Initializes the endpoint manager with a default endpoint, preferred regions for routing, + /// and an HTTP pipeline for communication. Sets up location cache for endpoint management + /// and account properties cache with 600 second TTL. The manager starts with empty endpoint + /// lists until the first account properties refresh populates regional endpoints. + /// + /// # Arguments + /// * `default_endpoint` - The primary Cosmos DB account endpoint URL + /// * `preferred_locations` - Ordered list of preferred Azure regions for request routing + /// * `excluded_regions` - List of regions to exclude from routing + /// * `pipeline` - HTTP pipeline for making service requests + /// + /// # Returns + /// A new `GlobalEndpointManager` instance ready for request routing + pub fn new( + default_endpoint: Url, + preferred_locations: Vec, + excluded_regions: Vec, + pipeline: Pipeline, + ) -> Arc { + let location_cache = Mutex::new(LocationCache::new( + default_endpoint.clone(), + preferred_locations.clone(), + excluded_regions.clone(), + )); + + let account_properties_cache = AsyncCache::new( + Some(Duration::seconds(600)), // Default 10 minutes TTL + ); + + let instance = Arc::new(Self { + default_endpoint, + location_cache, + pipeline, + account_properties_cache, + on_account_refresh: OnceLock::new(), + background_account_refresh_active: AtomicBool::new(false), + background_task_manager: BackgroundTaskManager::new(), + background_account_refresh_interval: Duration::seconds( + BACKGROUND_ACCOUNT_REFRESH_INTERVAL_SECS, + ), + }); + instance.initialize_and_start_background_account_refresh(); + instance + } + + /// Sets a callback to be invoked whenever account properties are refreshed via HTTP call. + /// + /// # Summary + /// Registers a callback function that will be called automatically whenever `refresh_location` + /// fetches new account properties from the service (not when serving from cache). This is useful + /// for updating partition-level failover configurations when account properties change. + /// + /// This method must be called at most once (during client construction). Subsequent calls + /// are silently ignored because the callback slot is backed by a [`OnceLock`]. + /// + /// # Arguments + /// * `callback` - The callback function to invoke with the refreshed account properties + pub fn set_on_account_refresh_callback(&self, callback: OnAccountRefreshCallback) { + // OnceLock::set returns Err if already initialized; we intentionally ignore that + // because the callback is expected to be set exactly once. + let _ = self.on_account_refresh.set(callback); + } + + /// Returns the default hub endpoint URL for the Cosmos DB account. + /// + /// # Summary + /// Retrieves the primary endpoint URL that was configured during manager initialization. + /// This is the main entry point for the Cosmos DB account and is used as a fallback + /// when no preferred regional endpoints are available or configured. + /// + /// # Returns + /// The default endpoint URL as a String + pub fn hub_uri(&self) -> &Url { + &self.default_endpoint + } + + /// Returns the list of available read endpoints. + /// + /// # Summary + /// Retrieves all currently available read endpoints from the location cache. The list + /// includes regional endpoints that can handle read operations and are not marked as + /// unavailable. Initially empty until account properties are fetched and processed. + /// + /// # Returns + /// A vector of endpoint URLs available for read operations + #[allow(dead_code)] + pub fn read_endpoints(&self) -> Vec { + self.location_cache + .lock() + .unwrap() + .read_endpoints() + .to_vec() + } + + /// Returns the list of available account read endpoints. + /// + /// # Summary + /// Alias for `read_endpoints()` that retrieves all currently available read endpoints + /// from the location cache. Provides the same functionality with an alternative name + /// for clarity in account-level operations context. + /// + /// # Returns + /// A vector of endpoint URLs available for read operations + #[allow(dead_code)] + pub fn account_read_endpoints(&self) -> Vec { + self.location_cache + .lock() + .unwrap() + .read_endpoints() + .to_vec() + } + + /// Returns the list of available write endpoints. + /// + /// # Summary + /// Retrieves all currently available write endpoints from the location cache. The list + /// includes regional endpoints that can handle write operations and are not marked as + /// unavailable. For multi-master accounts, this may include multiple regions; for + /// single-master accounts, typically only the write region. Initially empty until + /// account properties are fetched. + /// + /// # Returns + /// A vector of endpoint URLs available for write operations + #[allow(dead_code)] + pub fn write_endpoints(&self) -> Vec { + self.location_cache + .lock() + .unwrap() + .write_endpoints() + .to_vec() + } + + /// Resolves the appropriate service endpoint URL for a given request. + /// + /// # Summary + /// Determines which endpoint should handle the request based on operation type + /// (read vs write), resource type, preferred locations, and endpoint availability. + /// Delegates to the location cache which applies routing logic including regional + /// preferences and failover to available endpoints. + /// + /// # Arguments + /// * `request` - The Cosmos DB request requiring endpoint resolution + /// + /// # Returns + /// The resolved endpoint URL as a String + pub(crate) fn resolve_service_endpoint(&self, request: &CosmosRequest) -> Url { + self.location_cache + .lock() + .unwrap() + .resolve_service_endpoint(request) + } + + /// Returns all endpoints applicable for handling a specific request. + /// + /// # Summary + /// Retrieves the list of endpoints that could potentially handle the request based + /// on its operation type (read or write) and current endpoint availability. Used by + /// retry policies to determine how many alternative endpoints are available for + /// failover attempts. + /// + /// # Arguments + /// * `request` - The Cosmos DB request to evaluate + /// + /// # Returns + /// A vector of applicable endpoint URLs + pub fn applicable_endpoints( + &self, + operation_type: OperationType, + excluded_regions: Option<&Vec>, + ) -> Vec { + self.location_cache + .lock() + .unwrap() + .get_applicable_endpoints(operation_type, excluded_regions) + } + + /// Marks an endpoint as unavailable for read operations. + /// + /// # Summary + /// Flags the specified endpoint as unavailable for read requests in the location cache. + /// This is called by retry policies when read requests fail due to endpoint issues, + /// preventing subsequent read operations from being routed to the failing endpoint. + /// The endpoint may still be used for write operations if not separately marked unavailable. + /// + /// # Arguments + /// * `endpoint` - The endpoint URL to mark as unavailable for reads + pub fn mark_endpoint_unavailable_for_read(&self, endpoint: &Url) { + self.location_cache + .lock() + .unwrap() + .mark_endpoint_unavailable(endpoint, RequestOperation::Read) + } + + /// Marks an endpoint as unavailable for write operations. + /// + /// # Summary + /// Flags the specified endpoint as unavailable for write requests in the location cache. + /// This is called by retry policies when write requests fail due to endpoint issues, + /// preventing subsequent write operations from being routed to the failing endpoint. + /// The endpoint may still be used for read operations if not separately marked unavailable. + /// + /// # Arguments + /// * `endpoint` - The endpoint URL to mark as unavailable for writes + pub fn mark_endpoint_unavailable_for_write(&self, endpoint: &Url) { + self.location_cache + .lock() + .unwrap() + .mark_endpoint_unavailable(endpoint, RequestOperation::Write) + } + + /// Determines if a request can utilize multiple write locations. + /// + /// # Summary + /// Evaluates whether the given request can be routed to multiple write regions based + /// on the request's operation type and resource type. Returns true only for write + /// operations on resources that support multi-master writes (documents and stored + /// procedure executions) when the account is configured for multiple write locations. + /// + /// # Arguments + /// * `request` - The Cosmos DB request to evaluate + /// + /// # Returns + /// `true` if the request can use multiple write locations, `false` otherwise + pub fn can_use_multiple_write_locations(&self, request: &CosmosRequest) -> bool { + !request.is_read_only_request() + && self + .can_support_multiple_write_locations(request.resource_type, request.operation_type) + } + + /// Refreshes account properties and location information from the service. + /// + /// # Summary + /// Fetches the latest Cosmos DB account properties including regional endpoint information + /// and updates the location cache. Uses a Moka cache with 600 second TTL to avoid redundant + /// service calls. If `force_refresh` is true, invalidates the cache to ensure fresh data. + /// The location cache is updated only when new data is fetched (TTL expiry or forced refresh), + /// not when serving cached data. + /// + /// # Arguments + /// * `force_refresh` - If true, invalidates cache and forces fresh fetch from service + /// + /// # Returns + /// `Ok(())` if refresh succeeded, `Err` if fetching account properties failed + pub async fn refresh_location(&self, force_refresh: bool) -> Result<(), Error> { + // If force_refresh is true, invalidate the cache to ensure a fresh fetch + if force_refresh { + self.account_properties_cache + .remove(&ACCOUNT_PROPERTIES_KEY) + .await; + } + + // Flag to track if an HTTP call was made + let http_call_made = AtomicBool::new(false); + + // When TTL expires or cache is invalidated, the async block executes and updates location cache + let account_properties = self + .account_properties_cache + .get( + ACCOUNT_PROPERTIES_KEY, + |_| force_refresh, + || async { + // Fetch latest account properties from service + let account_properties: AccountProperties = + self.get_database_account().await?.into_body().json()?; + + // Mark that we're making an HTTP call + http_call_made.store(true, Ordering::SeqCst); + + // Update location cache with the fetched account properties (only on fresh fetch) + { + let mut cache = self.location_cache.lock().unwrap(); + cache.on_database_account_read(account_properties.clone()); + } + + Ok::(account_properties) + }, + ) + .await?; + + // Invoke the registered callback if an HTTP call was made. + // `OnceLock::get` is lock-free, so this is safe to call from async code. + let was_http_call_made = http_call_made.load(Ordering::SeqCst); + if was_http_call_made { + if let Some(callback) = self.on_account_refresh.get() { + callback(&account_properties); + } + } + + Ok(()) + } + + /// Returns a map of write endpoints indexed by location name. + /// + /// # Summary + /// Retrieves a mapping from Azure region names to their corresponding write endpoint URLs. + /// This provides direct lookup of write endpoints by location, useful for diagnostic + /// and monitoring scenarios. The map reflects the current account configuration and + /// may be empty until account properties are fetched. + /// + /// # Returns + /// A HashMap containing the location names with their corresponding write endpoint URLs + #[allow(dead_code)] + fn available_write_endpoints_by_location(&self) -> HashMap { + self.location_cache + .lock() + .unwrap() + .locations_info + .account_write_endpoints_by_location + .clone() + } + + /// Returns a map of read endpoints indexed by location name. + /// + /// # Summary + /// Retrieves a mapping from Azure region names to their corresponding read endpoint URLs. + /// This provides direct lookup of read endpoints by location, useful for diagnostic + /// and monitoring scenarios. The map reflects the current account configuration and + /// may be empty until account properties are fetched. + /// + /// # Returns + /// A HashMap mapping location names to read endpoint URLs + #[allow(dead_code)] + fn available_read_endpoints_by_location(&self) -> HashMap { + self.location_cache + .lock() + .unwrap() + .locations_info + .account_read_endpoints_by_location + .clone() + } + + /// Determines if the account supports multiple write locations for specific resource and operation types. + /// + /// # Summary + /// Evaluates whether multi-master writes are supported based on account configuration and + /// the specific resource/operation combination. Multi-master writes are supported for + /// Documents (all operations) and StoredProcedures (Execute operation only). Other resource + /// types like Databases, Containers, etc., do not support multi-write even in multi-master accounts. + /// + /// # Arguments + /// * `resource_type` - The type of resource being operated on + /// * `operation_type` - The type of operation being performed + /// + /// # Returns + /// `true` if multi-write is supported for the resource/operation, `false` otherwise + pub(crate) fn can_support_multiple_write_locations( + &self, + resource_type: ResourceType, + operation_type: OperationType, + ) -> bool { + let cache = self.location_cache.lock().unwrap(); + cache.can_support_multiple_write_locations(resource_type, operation_type) + } + + /// Retrieves the Cosmos DB account ("database account") properties from the service. + /// + /// # Summary + /// Makes an HTTP request to fetch account properties including regional endpoint information, + /// consistency settings, and multi-master configuration. Uses the default endpoint for the + /// request and constructs a metadata read operation with appropriate resource link. Called + /// internally by `refresh_location` when cache needs updating. + /// + /// # Returns + /// `Ok(Response)` with account metadata, or `Err` if request failed + pub async fn get_database_account(&self) -> azure_core::Result> { + let resource_link = ResourceLink::root(ResourceType::DatabaseAccount); + let builder = CosmosRequest::builder(OperationType::Read, resource_link.clone()); + let mut cosmos_request = builder.build()?; + let endpoint = self + .location_cache + .lock() + .unwrap() + .resolve_service_endpoint(&cosmos_request); + cosmos_request.request_context.location_endpoint_to_route = Some(endpoint); + let ctx_owned = Context::default().with_value(resource_link); + self.pipeline + .send(&ctx_owned, &mut cosmos_request.into_raw_request(), None) + .await + .map(Into::into) + } + + /// Initializes and starts the background account refresh loop. + /// + /// # Summary + /// Atomically checks and sets the `background_account_refresh_active` flag to ensure only + /// one background refresh task runs at a time. If the flag is already set, the call is a + /// no-op. Otherwise, it spawns a background task via [`BackgroundTaskManager`] that + /// periodically refreshes account properties. The spawned task captures a `Weak` + /// reference to avoid a reference cycle, allowing the `GlobalEndpointManager` to be + /// dropped normally, which in turn cancels the background task. + fn initialize_and_start_background_account_refresh(self: &Arc) { + // Atomically try to set from false to true. + // If it was already true, another thread already started the task. + if self + .background_account_refresh_active + .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst) + .is_err() + { + return; + } + + let weak_self = Arc::downgrade(self); + // Spawn via BackgroundTaskManager so the task is tracked and will be + // canceled when the manager (and thus the client) is dropped. + // We capture a Weak (not Arc) to avoid a reference cycle + // that would prevent the GlobalEndpointManager from ever being dropped. + self.background_task_manager.spawn(Box::pin(async move { + Self::initiate_background_account_refresh_loop(weak_self).await; + })); + } + + /// Runs the background account refresh loop that periodically updates location information. + /// + /// # Summary + /// Executes an infinite loop that sleeps for the configured refresh interval and then + /// calls [`refresh_location`](Self::refresh_location) with `force_refresh: true` to + /// fetch the latest account properties from the service. The loop holds only a + /// [`Weak`] reference to the `GlobalEndpointManager`; if the manager has been dropped + /// (i.e., the `Weak` upgrade fails), the loop exits gracefully. Any errors during + /// refresh are logged but do not terminate the loop. + /// + /// # Arguments + /// * `weak_self` - A weak reference to the owning `GlobalEndpointManager` + async fn initiate_background_account_refresh_loop(weak_self: Weak) { + // Briefly upgrade to read the interval, then release the strong ref + // so it does not keep Self alive across the sleep. + let interval = match weak_self.upgrade() { + Some(strong) => strong.background_account_refresh_interval, + None => return, + }; + + loop { + // Use the runtime-agnostic sleep from azure_core + azure_core::async_runtime::get_async_runtime() + .sleep(interval) + .await; + + // Upgrade the Weak ref for this iteration only. If it fails, the + // manager has been dropped and we should exit. + let strong = match weak_self.upgrade() { + Some(s) => s, + None => { + info!("GlobalEndpointManager: background refresh loop exiting because the client has been dropped."); + return; + } + }; + + info!("GlobalEndpointManager: refresh_location() trying to refresh database account."); + + if let Err(e) = strong.refresh_location(true).await { + tracing::error!("GlobalEndpointManager: initiate_background_account_refresh_loop() - failed to refresh database account. Exception: {}", e); + } + // `strong` is dropped here, releasing the temporary strong ref + // before the next sleep. + } + } + + /// Updates the location cache with the given write and read regions. + /// + /// This is exposed as `pub(crate)` to allow other modules' tests to populate + /// endpoints without requiring a live service call to `refresh_location`. + #[cfg(test)] + pub(crate) fn update_location_cache( + &self, + write_locations: Vec, + read_locations: Vec, + ) { + let _ = self + .location_cache + .lock() + .unwrap() + .update(write_locations, read_locations); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::models::AccountRegion; + use crate::partition_key::PartitionKey; + + fn create_test_pipeline() -> Pipeline { + Pipeline::new( + option_env!("CARGO_PKG_NAME"), + option_env!("CARGO_PKG_VERSION"), + azure_core::http::ClientOptions::default(), + Vec::new(), + Vec::new(), + None, + ) + } + + fn create_test_manager() -> Arc { + GlobalEndpointManager::new( + "https://test.documents.azure.com".parse().unwrap(), + vec![Region::from("West US"), Region::from("East US")], + vec![], + create_test_pipeline(), + ) + } + + fn create_test_request(operation_type: OperationType) -> CosmosRequest { + let resource_link = ResourceLink::root(ResourceType::Documents); + let mut request = CosmosRequest::builder(operation_type, resource_link.clone()) + .partition_key(PartitionKey::from("test")) + .build() + .unwrap(); + + request.request_context.location_endpoint_to_route = + Some("https://test.documents.azure.com".parse().unwrap()); + request + } + + #[tokio::test] + async fn test_new_manager_initialization() { + let manager = create_test_manager(); + assert_eq!( + manager.hub_uri(), + &Url::parse("https://test.documents.azure.com/").unwrap() + ); + } + + #[tokio::test] + async fn test_hub_uri() { + let manager = create_test_manager(); + let hub_uri = manager.hub_uri(); + assert_eq!( + hub_uri, + &Url::parse("https://test.documents.azure.com/").unwrap() + ); + } + + #[tokio::test] + async fn test_resolve_service_endpoint_returns_default() { + let manager = create_test_manager(); + let request = create_test_request(OperationType::Read); + let endpoint = manager.resolve_service_endpoint(&request); + // Should return default endpoint initially + assert_eq!( + endpoint, + Url::parse("https://test.documents.azure.com/").unwrap() + ); + } + + #[tokio::test] + async fn test_read_endpoints_initial_state() { + let manager = create_test_manager(); + let endpoints = manager.read_endpoints(); + // Initial state may be empty until account properties are loaded + // Just verify it returns a valid vector and doesn't panic + let _ = endpoints.len(); + } + + #[tokio::test] + async fn test_write_endpoints_initial_state() { + let manager = create_test_manager(); + let endpoints = manager.write_endpoints(); + // Initial state may be empty until account properties are loaded + // Just verify it returns a valid vector and doesn't panic + let _ = endpoints.len(); + } + + #[tokio::test] + async fn test_mark_endpoint_unavailable_for_read() { + let manager = create_test_manager(); + let endpoint = "https://test.documents.azure.com".parse().unwrap(); + let account_region = AccountRegion { + name: Region::from("West US".to_string()), + database_account_endpoint: "https://test.documents.azure.com".parse().unwrap(), + }; + // Populate the location cache's regions + let _ = manager + .location_cache + .lock() + .unwrap() + .update(vec![account_region.clone()], vec![account_region]); + + // This should not panic + manager.mark_endpoint_unavailable_for_read(&endpoint); + + // The endpoint should still be in the system but marked unavailable + let read_endpoints = manager.read_endpoints(); + assert!(!read_endpoints.is_empty()); + } + + #[tokio::test] + async fn test_mark_endpoint_unavailable_for_write() { + let manager = create_test_manager(); + let endpoint = "https://test.documents.azure.com".parse().unwrap(); + let account_region = AccountRegion { + name: Region::from("West US".to_string()), + database_account_endpoint: "https://test.documents.azure.com".parse().unwrap(), + }; + // Populate the location cache's regions + let _ = manager + .location_cache + .lock() + .unwrap() + .update(vec![account_region.clone()], vec![account_region]); + + // This should not panic + manager.mark_endpoint_unavailable_for_write(&endpoint); + + // The endpoint should still be in the system but marked unavailable + let write_endpoints = manager.write_endpoints(); + assert!(!write_endpoints.is_empty()); + } + + #[tokio::test] + async fn test_can_use_multiple_write_locations_for_read_request() { + let manager = create_test_manager(); + let request = create_test_request(OperationType::Read); + + // Read requests should not use multiple write locations + assert!(!manager.can_use_multiple_write_locations(&request)); + } + + #[tokio::test] + async fn test_can_use_multiple_write_locations_for_write_request() { + let manager = create_test_manager(); + let request = create_test_request(OperationType::Create); + + // Whether this returns true or false depends on account configuration + // Just verify it doesn't panic + let _ = manager.can_use_multiple_write_locations(&request); + } + + #[tokio::test] + async fn test_can_support_multiple_write_locations_for_documents() { + let manager = create_test_manager(); + + // Documents should potentially support multiple write locations + // The actual result depends on account configuration + let _ = manager + .can_support_multiple_write_locations(ResourceType::Documents, OperationType::Create); + } + + #[tokio::test] + async fn test_can_support_multiple_write_locations_for_stored_procedures() { + let manager = create_test_manager(); + + // Stored procedures with Execute operation should potentially support multiple write locations + let _ = manager.can_support_multiple_write_locations( + ResourceType::StoredProcedures, + OperationType::Execute, + ); + } + + #[tokio::test] + async fn test_can_support_multiple_write_locations_for_databases() { + let manager = create_test_manager(); + + // Database operations should not support multiple write locations + let result = manager + .can_support_multiple_write_locations(ResourceType::Databases, OperationType::Create); + + // Databases don't support multi-write + assert!(!result); + } + + #[tokio::test] + async fn test_applicable_endpoints() { + let manager = create_test_manager(); + let endpoints = manager.applicable_endpoints(OperationType::Read, None); + assert!(!endpoints.is_empty()); + } + + #[tokio::test] + async fn test_applicable_excluded_endpoints() { + let manager = create_test_manager(); + // Exclude all regions to test behavior - should still return default endpoint + let excluded_regions: Vec = vec![Region::from("West US"), Region::from("East US")]; + let endpoints = manager.applicable_endpoints(OperationType::Read, Some(&excluded_regions)); + assert!(!endpoints.is_empty()); + let endpoints = + manager.applicable_endpoints(OperationType::Create, Some(&excluded_regions)); + assert!(!endpoints.is_empty()); + } + + #[tokio::test] + async fn test_account_read_endpoints() { + let manager = create_test_manager(); + let endpoints = manager.account_read_endpoints(); + + // Should return the same as read_endpoints + assert_eq!(endpoints, manager.read_endpoints()); + } + + #[tokio::test] + async fn test_available_write_endpoints_by_location() { + let manager = create_test_manager(); + let endpoints_map = manager.available_write_endpoints_by_location(); + + // Should not panic and return a valid map + let _ = endpoints_map.len(); + } + + #[tokio::test] + async fn test_available_read_endpoints_by_location() { + let manager = create_test_manager(); + let endpoints_map = manager.available_read_endpoints_by_location(); + + // Should not panic and return a valid map + let _ = endpoints_map.len(); + } +} diff --git a/sdk/cosmos/azure_data_cosmos/src/session_helpers.rs b/sdk/cosmos/azure_data_cosmos/src/session_helpers.rs index 03211351e81..d83375b9070 100644 --- a/sdk/cosmos/azure_data_cosmos/src/session_helpers.rs +++ b/sdk/cosmos/azure_data_cosmos/src/session_helpers.rs @@ -3,7 +3,7 @@ //! Helpers for merging and managing session tokens across feed ranges. -use crate::feed_range::FeedRange; +use crate::FeedRange; use azure_core::error::ErrorKind; use azure_data_cosmos_driver::models::{SessionToken, SessionTokenSegment}; @@ -99,9 +99,9 @@ fn merge_ranges_with_subsets( // Sort by range size descending: larger ranges (parents) first. // Primary: max_exclusive descending, secondary: min_inclusive ascending. overlapping.sort_by(|(a, _), (b, _)| { - b.max_exclusive - .cmp(&a.max_exclusive) - .then(a.min_inclusive.cmp(&b.min_inclusive)) + b.max_exclusive() + .cmp(a.max_exclusive()) + .then(a.min_inclusive().cmp(b.min_inclusive())) }); let mut processed = Vec::new(); @@ -188,7 +188,7 @@ fn analyze_subsets( ) -> azure_core::Result { // Sort subsets by min_inclusive so adjacent children are always in order let mut sorted_subsets = subsets.to_vec(); - sorted_subsets.sort_by(|a, b| a.1.min_inclusive.cmp(&b.1.min_inclusive)); + sorted_subsets.sort_by(|a, b| a.1.min_inclusive().cmp(b.1.min_inclusive())); for start_idx in 0..sorted_subsets.len() { let mut merged_range = sorted_subsets[start_idx].1.clone(); @@ -354,13 +354,10 @@ pub(crate) fn get_latest_session_token( #[cfg(test)] mod tests { use super::*; - use crate::hash::EffectivePartitionKey; + use azure_data_cosmos_driver::models::effective_partition_key::EffectivePartitionKey as DriverEpk; fn fr(min: &str, max: &str) -> FeedRange { - FeedRange { - min_inclusive: EffectivePartitionKey::from(min), - max_exclusive: EffectivePartitionKey::from(max), - } + FeedRange::new(DriverEpk::from(min), DriverEpk::from(max)) } fn st(s: &str) -> SessionToken { diff --git a/sdk/cosmos/azure_data_cosmos/tests/emulator_tests/cosmos_containers.rs b/sdk/cosmos/azure_data_cosmos/tests/emulator_tests/cosmos_containers.rs index cfebc801a37..367a692bc12 100644 --- a/sdk/cosmos/azure_data_cosmos/tests/emulator_tests/cosmos_containers.rs +++ b/sdk/cosmos/azure_data_cosmos/tests/emulator_tests/cosmos_containers.rs @@ -72,11 +72,13 @@ pub async fn container_crud_simple() -> Result<(), Box> { indexing_policy.indexing_mode.unwrap() ); - let mut query_pager = db_client.query_containers( - Query::from("SELECT * FROM root r WHERE r.id = @id") - .with_parameter("@id", &properties.id)?, - None, - )?; + let mut query_pager = db_client + .query_containers( + Query::from("SELECT * FROM root r WHERE r.id = @id") + .with_parameter("@id", &properties.id)?, + None, + ) + .await?; let mut ids = vec![]; while let Some(db) = query_pager.try_next().await? { ids.push(db.id); @@ -120,11 +122,13 @@ pub async fn container_crud_simple() -> Result<(), Box> { container_client.delete(None).await?; - query_pager = db_client.query_containers( - Query::from("SELECT * FROM root r WHERE r.id = @id") - .with_parameter("@id", &properties.id)?, - None, - )?; + query_pager = db_client + .query_containers( + Query::from("SELECT * FROM root r WHERE r.id = @id") + .with_parameter("@id", &properties.id)?, + None, + ) + .await?; let mut ids = vec![]; while let Some(db) = query_pager.try_next().await? { ids.push(db.id); diff --git a/sdk/cosmos/azure_data_cosmos/tests/emulator_tests/cosmos_databases.rs b/sdk/cosmos/azure_data_cosmos/tests/emulator_tests/cosmos_databases.rs index 2a30d13bed6..43bc5786bb2 100644 --- a/sdk/cosmos/azure_data_cosmos/tests/emulator_tests/cosmos_databases.rs +++ b/sdk/cosmos/azure_data_cosmos/tests/emulator_tests/cosmos_databases.rs @@ -37,7 +37,7 @@ pub async fn database_crud() -> Result<(), Box> { let query = Query::from("SELECT * FROM root r WHERE r.id = @id") .with_parameter("@id", &test_db_id)?; - let mut pager = cosmos_client.query_databases(query.clone(), None)?; + let mut pager = cosmos_client.query_databases(query.clone(), None).await?; let mut ids = Vec::new(); while let Some(db) = pager.try_next().await? { ids.push(db.id); @@ -50,7 +50,7 @@ pub async fn database_crud() -> Result<(), Box> { // We're testing delete, so we want to manually delete the DB rather than letting the clean-up process do it. db_client.delete(None).await?; - let mut pager = cosmos_client.query_databases(query, None)?; + let mut pager = cosmos_client.query_databases(query, None).await?; let mut ids = Vec::new(); while let Some(db) = pager.try_next().await? { ids.push(db.id); diff --git a/sdk/cosmos/azure_data_cosmos/tests/emulator_tests/cosmos_query.rs b/sdk/cosmos/azure_data_cosmos/tests/emulator_tests/cosmos_query.rs index 943641312f0..3e6a9dec618 100644 --- a/sdk/cosmos/azure_data_cosmos/tests/emulator_tests/cosmos_query.rs +++ b/sdk/cosmos/azure_data_cosmos/tests/emulator_tests/cosmos_query.rs @@ -11,12 +11,15 @@ use std::error::Error; use azure_core::http::headers::HeaderValue; use azure_core::http::StatusCode; use azure_data_cosmos::{ + clients::DatabaseClient, constants, options::{OperationOptions, QueryOptions}, - Query, + query::QueryScope, + ContinuationToken, Query, }; use framework::{test_data, MockItem, TestClient}; -use futures::{StreamExt, TryStreamExt}; +use futures::StreamExt; +use serde::de::DeserializeOwned; fn collect_matching_items( items: &[MockItem], @@ -25,6 +28,98 @@ fn collect_matching_items( items.iter().filter(|p| predicate(p)).cloned().collect() } +#[derive(Default)] +struct QueryTestOptions { + max_item_count: Option, + use_continuation_token_resume: bool, +} + +async fn execute_query_test( + db_client: &DatabaseClient, + items: Vec, + query: impl Into, + scope: QueryScope, + expected_items: Vec, + options: QueryTestOptions, +) -> Result<(), Box> +where + T: DeserializeOwned + Send + Eq + std::fmt::Debug + 'static, +{ + let container_client = test_data::create_container_with_items(db_client, items, None).await?; + let query: Query = query.into(); + + let build_options = || -> QueryOptions { + let mut o = QueryOptions::default(); + if let Some(max_item_count) = options.max_item_count { + o = o.with_max_item_count(max_item_count); + } + o + }; + + let mut actual_items = Vec::new(); + + if options.use_continuation_token_resume { + // Fetch one page at a time, taking a continuation token after each + // page and resuming a brand-new iterator from the token. This + // exercises the suspend/resume path end-to-end. + let mut continuation: Option = None; + loop { + let mut query_options = build_options(); + if let Some(token) = continuation.take() { + query_options = query_options.with_continuation_token(token); + } + let mut pages = container_client + .query_items::(query.clone(), scope.clone(), Some(query_options)) + .await? + .into_pages(); + + let Some(page) = pages.next().await else { + break; + }; + let page = page?; + actual_items.extend(page.into_items()); + + // Round-trip the continuation token through string form to + // mimic real usage (e.g. persisting it across processes). + let token = pages.to_continuation_token()?; + let serialized = token.as_str().to_owned(); + let restored = ContinuationToken::from_string(serialized); + // Drop the iterator before checking for termination — we want to + // observe the snapshot taken right after the page was emitted. + drop(pages); + + // The pipeline reports its own terminal state via + // `to_continuation_token` returning a token whose decoded + // snapshot is `Drained`. We can't introspect that here, so we + // detect termination by attempting one more poll on a fresh + // iterator: if it yields no page, we're done. + // + // To avoid an extra round-trip when the snapshot is trivially + // drained, we still always set `continuation` and let the + // planner short-circuit to a `DrainedLeaf`. + continuation = Some(restored); + } + } else { + let mut pages = container_client + .query_items::(query, scope, Some(build_options())) + .await? + .into_pages(); + while let Some(page) = pages.next().await { + actual_items.extend(page?.into_items()); + } + } + + assert_eq!(expected_items, actual_items); + Ok(()) +} + +#[derive(serde::Deserialize, Debug, PartialEq, Eq)] +#[serde(rename_all = "camelCase")] +struct ItemProjection { + id: String, + merge_order: usize, +} + #[tokio::test] #[cfg_attr( not(test_category = "emulator"), @@ -32,18 +127,20 @@ fn collect_matching_items( )] pub async fn single_partition_query_simple() -> Result<(), Box> { TestClient::run_with_unique_db( - async |run_context, db_client| { + async |_, db_client| { let items = test_data::generate_mock_items(10, 10); - let container_client = - test_data::create_container_with_items(db_client, items.clone(), None).await?; + let expected_items = + collect_matching_items(&items, |p| p.partition_key == "partition0"); - let result_items: Vec = run_context - .query_items(&container_client, "select * from docs c", "partition0") - .await?; - assert_eq!( - collect_matching_items(&items, |p| p.partition_key == "partition0"), - result_items - ); + execute_query_test( + db_client, + items, + "select * from docs c", + QueryScope::partition("partition0"), + expected_items, + QueryTestOptions::default(), + ) + .await?; Ok(()) }, @@ -59,10 +156,8 @@ pub async fn single_partition_query_simple() -> Result<(), Box> { )] pub async fn single_partition_query_with_parameters() -> Result<(), Box> { TestClient::run_with_unique_db( - async |run_context, db_client| { + async |_, db_client| { let items = test_data::generate_mock_items(10, 10); - let container_client = - test_data::create_container_with_items(db_client, items.clone(), None).await?; // Find a merge order value in partition1's items let merge_order = items @@ -74,13 +169,17 @@ pub async fn single_partition_query_with_parameters() -> Result<(), Box = run_context - .query_items(&container_client, query, "partition1") - .await?; - assert_eq!( - collect_matching_items(&items, |p| p.merge_order == merge_order), - result_items - ); + let expected_items = collect_matching_items(&items, |p| p.merge_order == merge_order); + + execute_query_test( + db_client, + items, + query, + QueryScope::partition("partition1"), + expected_items, + QueryTestOptions::default(), + ) + .await?; Ok(()) }, @@ -96,22 +195,26 @@ pub async fn single_partition_query_with_parameters() -> Result<(), Box Result<(), Box> { TestClient::run_with_unique_db( - async |run_context, db_client| { + async |_, db_client| { let items = test_data::generate_mock_items(10, 10); - let container_client = - test_data::create_container_with_items(db_client, items.clone(), None).await?; - - let result_items: Vec = run_context - .query_items(&container_client, "select value c.id from c", "partition1") - .await?; - assert_eq!( - items - .iter() - .filter(|p| p.partition_key == "partition1") - .map(|p| p.id.to_string()) - .collect::>(), - result_items - ); + let expected_items = items + .iter() + .filter(|p| p.partition_key == "partition1") + .map(|p| ItemProjection { + id: p.id.to_string(), + merge_order: p.merge_order, + }) + .collect::>(); + + execute_query_test( + db_client, + items, + "select c.id, c.mergeOrder from c", + QueryScope::partition("partition1"), + expected_items, + QueryTestOptions::default(), + ) + .await?; Ok(()) }, @@ -127,27 +230,23 @@ pub async fn single_partition_query_with_projection() -> Result<(), Box Result<(), Box> { TestClient::run_with_unique_db( - async |run_context, db_client| { + async |_, db_client| { let items = test_data::generate_mock_items(10, 2); - let container_client = - test_data::create_container_with_items(db_client, items.clone(), None).await?; - - let result_items: Vec = run_context - .query_items( - &container_client, - "select value c.id from c where c.mergeOrder between 40 and 60", - (), - ) - .await?; - - assert_eq!( - items - .iter() - .filter(|p| p.merge_order >= 40 && p.merge_order <= 60) - .map(|p| p.id.to_string()) - .collect::>(), - result_items - ); + let expected_items = items + .iter() + .filter(|p| p.merge_order >= 40 && p.merge_order <= 60) + .map(|p| p.id.to_string()) + .collect::>(); + + execute_query_test( + db_client, + items, + "select value c.id from c where c.mergeOrder between 40 and 60", + QueryScope::full_container(), + expected_items, + QueryTestOptions::default(), + ) + .await?; Ok(()) }, @@ -161,38 +260,59 @@ pub async fn cross_partition_query_with_projection_and_filter() -> Result<(), Bo not(test_category = "emulator"), ignore = "requires test_category 'emulator'" )] -pub async fn cross_partition_query_with_order_by_fails_without_query_engine( -) -> Result<(), Box> { +pub async fn cross_partition_query_with_order_by_fails() -> Result<(), Box> { TestClient::run_with_unique_db( async |_, db_client| { let items = test_data::generate_mock_items(10, 10); let container_client = test_data::create_container_with_items(db_client, items.clone(), None).await?; - let mut pager = container_client.query_items::( - "select value c.id from c order by c.mergeOrder", - (), - None, - )?; - let result = pager.try_next().await; - - let Err(err) = result else { - panic!("expected an error but got a successful result"); + let Err(err) = container_client + .query_items::( + "select value c.id from c order by c.mergeOrder", + QueryScope::full_container(), + None, + ) + .await + else { + panic!("Expected query to fail due to cross-partition ORDER BY"); }; - assert_eq!(Some(StatusCode::BadRequest), err.http_status()); - - let response = - if let azure_core::error::ErrorKind::HttpResponse { raw_response, .. } = err.kind() - { - raw_response.as_ref().unwrap().clone() - } else { - panic!("expected an HTTP response error"); - }; - let sub_status = response.headers().get_optional_str(&constants::SUB_STATUS); - - // 1004 = CrossPartitionQueryNotServable - assert_eq!(Some("1004"), sub_status); + match err.kind() { + azure_core::error::ErrorKind::HttpResponse { + status, + raw_response, + .. + } => { + assert_eq!( + *status, + StatusCode::BadRequest, + "Expected 400 Bad Request for cross-partition ORDER BY" + ); + let raw_response = raw_response.as_ref().unwrap(); + let body = std::str::from_utf8(raw_response.body()).unwrap(); + #[derive(serde::Deserialize)] + struct ErrorDetail { + code: String, + message: String, + } + let error_detail: ErrorDetail = serde_json::from_str(body).unwrap(); + assert_eq!(error_detail.code, "BadRequest"); + + // Take only the first two lines of the message for comparison, since the full message may contain additional details that could change over time + let clean_message = error_detail + .message + .lines() + .take(2) + .collect::>() + .join("\n"); + assert_eq!( + clean_message, + "Query contains 1 or more unsupported features. Upgrade your SDK to a version that does support the requested features:\nQuery contained OrderBy, which the calling client does not support." + ); + } + _ => panic!("Expected HTTP error response for cross-partition ORDER BY"), + } Ok(()) }, None, @@ -226,7 +346,12 @@ pub async fn query_returns_index_and_query_metrics() -> Result<(), Box("select * from c", "partition0", Some(options))? + .query_items::( + "select * from c", + QueryScope::partition("partition0"), + Some(options), + ) + .await? .into_pages(); // Get the first page and check metrics headers @@ -303,9 +428,6 @@ pub async fn single_partition_query_pagination() -> Result<(), Box> { TestClient::run_with_unique_db( async |_, db_client| { let items = test_data::generate_mock_items(1, 5); - let container_client = - test_data::create_container_with_items(db_client, items.clone(), None).await?; - let expected_items = collect_matching_items(&items, |p| p.partition_key == "partition0"); assert!( @@ -313,37 +435,121 @@ pub async fn single_partition_query_pagination() -> Result<(), Box> { "need multiple items to test pagination" ); - // Force 1 item per page to exercise continuation token pagination - let mut custom_headers = HashMap::new(); - custom_headers.insert(constants::MAX_ITEM_COUNT, HeaderValue::from_static("1")); - let operation = OperationOptions::default().with_custom_headers(custom_headers); - let options = QueryOptions::default().with_operation_options(operation); + execute_query_test( + db_client, + items, + "select * from c", + QueryScope::partition("partition0"), + expected_items, + QueryTestOptions { + max_item_count: Some(1), + use_continuation_token_resume: false, + }, + ) + .await?; - let mut pages = container_client - .query_items::("select * from c", "partition0", Some(options))? - .into_pages(); + Ok(()) + }, + None, + ) + .await +} - let mut all_items = Vec::new(); - let mut page_count = 0; +#[tokio::test] +#[cfg_attr( + not(test_category = "emulator"), + ignore = "requires test_category 'emulator'" +)] +pub async fn cross_partition_query_pagination() -> Result<(), Box> { + TestClient::run_with_unique_db( + async |_, db_client| { + let items = test_data::generate_mock_items(3, 3); - while let Some(page) = pages.next().await { - let page = page?; - assert!( - page.items().len() <= 1, - "expected at most 1 item per page, got {}", - page.items().len() - ); - all_items.extend(page.into_items()); - page_count += 1; - } + execute_query_test( + db_client, + items.clone(), + "select * from c", + QueryScope::full_container(), + items, + QueryTestOptions { + max_item_count: Some(1), + use_continuation_token_resume: false, + }, + ) + .await?; + + Ok(()) + }, + None, + ) + .await +} +#[tokio::test] +#[cfg_attr( + not(test_category = "emulator"), + ignore = "requires test_category 'emulator'" +)] +pub async fn cross_partition_query_suspend_resume() -> Result<(), Box> { + TestClient::run_with_unique_db( + async |_, db_client| { + // Four logical partitions × three items per partition. With a + // page size of one, this exercises both intra-partition and + // cross-partition resume points. + let items = test_data::generate_mock_items(4, 3); + + execute_query_test( + db_client, + items.clone(), + "select * from c", + QueryScope::full_container(), + items, + QueryTestOptions { + max_item_count: Some(1), + use_continuation_token_resume: true, + }, + ) + .await?; + + Ok(()) + }, + None, + ) + .await +} + +#[tokio::test] +#[cfg_attr( + not(test_category = "emulator"), + ignore = "requires test_category 'emulator'" +)] +pub async fn query_rejects_newer_sdk_continuation_token() -> Result<(), Box> { + TestClient::run_with_unique_db( + async |_, db_client| { + let items = test_data::generate_mock_items(1, 1); + let container_client = + test_data::create_container_with_items(db_client, items, None).await?; + + // A `c2.` prefix indicates the token was issued by a future + // SDK version this client does not understand. + let token = ContinuationToken::from_string("c2.something".to_string()); + let options = QueryOptions::default().with_continuation_token(token); + + let Err(err) = container_client + .query_items::( + "select * from c", + QueryScope::full_container(), + Some(options), + ) + .await + else { + panic!("expected newer-SDK token to be rejected"); + }; + let message = err.to_string(); assert!( - page_count >= expected_items.len(), - "expected at least {} pages with max-item-count=1, got {}", - expected_items.len(), - page_count + message.contains("newer SDK") || message.contains("c2"), + "unexpected error: {message}" ); - assert_eq!(expected_items, all_items); Ok(()) }, @@ -357,49 +563,185 @@ pub async fn single_partition_query_pagination() -> Result<(), Box> { not(test_category = "emulator"), ignore = "requires test_category 'emulator'" )] -pub async fn cross_partition_query_pagination() -> Result<(), Box> { +pub async fn query_rejects_server_token_for_cross_partition() -> Result<(), Box> { TestClient::run_with_unique_db( async |_, db_client| { - let items = test_data::generate_mock_items(3, 3); + let items = test_data::generate_mock_items(2, 1); let container_client = - test_data::create_container_with_items(db_client, items.clone(), None).await?; + test_data::create_container_with_items(db_client, items, None).await?; + + // An un-prefixed token is treated as an opaque server + // continuation, which is only valid for trivial (single- + // partition) queries. + let token = ContinuationToken::from_string("opaque-server-blob".to_string()); + let options = QueryOptions::default().with_continuation_token(token); + + let Err(err) = container_client + .query_items::( + "select * from c", + QueryScope::full_container(), + Some(options), + ) + .await + else { + panic!("expected opaque server token to be rejected for cross-partition query"); + }; + let message = err.to_string(); + assert!( + message.contains("opaque server continuation token"), + "unexpected error: {message}" + ); - // Force 1 item per page for cross-partition query - let mut custom_headers = HashMap::new(); - custom_headers.insert(constants::MAX_ITEM_COUNT, HeaderValue::from_static("1")); - let operation = OperationOptions::default().with_custom_headers(custom_headers); - let options = QueryOptions::default().with_operation_options(operation); + Ok(()) + }, + None, + ) + .await +} +#[tokio::test] +#[cfg_attr( + not(test_category = "emulator"), + ignore = "requires test_category 'emulator'" +)] +pub async fn single_partition_query_resumes_with_raw_server_token() -> Result<(), Box> { + use base64::engine::general_purpose::URL_SAFE_NO_PAD; + use base64::Engine as _; + + TestClient::run_with_unique_db( + async |_, db_client| { + // One logical partition × five items so we get multiple pages + // with `max_item_count(1)`. + let items = test_data::generate_mock_items(1, 5); + let expected: Vec = + collect_matching_items(&items, |p| p.partition_key == "partition0"); + assert!( + expected.len() > 1, + "need multiple items to exercise pagination" + ); + + let container_client = + test_data::create_container_with_items(db_client, items, None).await?; + let scope = QueryScope::partition("partition0"); + + // --- Round 1: fetch the first page through the SDK and pull + // the SDK-issued `c1.` token. --- let mut pages = container_client - .query_items::("select * from c", (), Some(options))? + .query_items::( + "select * from c", + scope.clone(), + Some(QueryOptions::default().with_max_item_count(1)), + ) + .await? .into_pages(); - let mut all_items = Vec::new(); - let mut page_count = 0; + let first_page = pages + .next() + .await + .expect("expected at least one page from the server")?; + let mut actual: Vec = first_page.into_items(); - while let Some(page) = pages.next().await { - let page = page?; - assert!( - page.items().len() <= 1, - "expected at most 1 item per page, got {}", - page.items().len() - ); - all_items.extend(page.into_items()); - page_count += 1; - } + let token = pages.to_continuation_token()?; + let raw = token.as_str().to_owned(); + drop(pages); assert!( - page_count > 1, - "expected multiple pages with max-item-count=1, got {}", - page_count + raw.starts_with("c1."), + "expected SDK to emit a c1.-prefixed token, got: {raw}" ); - // Cross-partition ordering is not guaranteed, so just check count + + // Crack the SDK token open. We deliberately couple this test + // to the on-the-wire format so we can recover the underlying + // server continuation without exposing extra public APIs. + // + // Format: `c1.` + base64url-no-pad(JSON of `PipelineNodeState`). + // For a trivial single-partition query the JSON is shaped like + // `{"kind":"request","server_continuation":""}`. + let payload = raw.strip_prefix("c1.").unwrap(); + let json_bytes = URL_SAFE_NO_PAD + .decode(payload) + .expect("c1. payload must be valid base64url-no-pad"); + let snapshot: serde_json::Value = serde_json::from_slice(&json_bytes) + .expect("decoded c1. payload must be valid JSON"); assert_eq!( - items.len(), - all_items.len(), - "expected all items to be returned across pages" + snapshot.get("kind").and_then(|v| v.as_str()), + Some("request"), + "trivial single-partition pipeline should snapshot as a single Request node, got: {snapshot}" + ); + let server_token = snapshot + .get("server_continuation") + .and_then(|v| v.as_str()) + .expect("Request node must carry a server_continuation after the first page") + .to_owned(); + assert!( + !server_token.is_empty(), + "server continuation token should not be empty" ); + assert!( + !server_token.starts_with("c1.") && !server_token.starts_with("c2."), + "server continuation must not look like an SDK token, got: {server_token}" + ); + + // --- Round 2: drain the rest of the query using the raw + // server token directly (no `c1.` prefix). The SDK accepts + // un-prefixed tokens as an opaque server fallback for trivial + // single-partition queries. --- + let mut continuation = Some(ContinuationToken::from_string(server_token)); + let mut page_count: usize = 1; + loop { + let mut options = QueryOptions::default().with_max_item_count(1); + if let Some(t) = continuation.take() { + options = options.with_continuation_token(t); + } + + let mut pages = container_client + .query_items::("select * from c", scope.clone(), Some(options)) + .await? + .into_pages(); + + let Some(page) = pages.next().await else { + break; + }; + let page = page?; + let items_in_page = page.into_items(); + let was_empty = items_in_page.is_empty(); + actual.extend(items_in_page); + page_count += 1; + + let next_token = pages.to_continuation_token()?; + let raw_next = next_token.as_str().to_owned(); + drop(pages); + + // Subsequent SDK-issued tokens must still be `c1.`-prefixed. + assert!( + raw_next.starts_with("c1."), + "follow-up token must remain c1.-prefixed, got: {raw_next}" + ); + + // Decode again to detect end-of-stream: when the inner + // snapshot is `{"kind":"drained"}` we are done. + let payload = raw_next.strip_prefix("c1.").unwrap(); + let json_bytes = URL_SAFE_NO_PAD + .decode(payload) + .expect("c1. payload must be valid base64url-no-pad"); + let snapshot: serde_json::Value = + serde_json::from_slice(&json_bytes).expect("payload must be valid JSON"); + let kind = snapshot.get("kind").and_then(|v| v.as_str()).unwrap_or(""); + if kind == "drained" || was_empty { + break; + } + + // Continue feeding the SDK its own next-token. + continuation = Some(ContinuationToken::from_string(raw_next)); + + assert!( + page_count <= expected.len() + 2, + "fetched more pages ({page_count}) than expected ({})", + expected.len() + ); + } + assert_eq!(expected, actual); Ok(()) }, None, diff --git a/sdk/cosmos/azure_data_cosmos/tests/framework/test_client.rs b/sdk/cosmos/azure_data_cosmos/tests/framework/test_client.rs index 402553ca57b..4971fa41eac 100644 --- a/sdk/cosmos/azure_data_cosmos/tests/framework/test_client.rs +++ b/sdk/cosmos/azure_data_cosmos/tests/framework/test_client.rs @@ -11,6 +11,7 @@ use azure_data_cosmos::clients::ContainerClient; use azure_data_cosmos::fault_injection::FaultInjectionClientBuilder; use azure_data_cosmos::models::{ItemResponse, ThroughputProperties}; use azure_data_cosmos::options::ItemReadOptions; +use azure_data_cosmos::query::QueryScope; use azure_data_cosmos::Region; use azure_data_cosmos::{ clients::DatabaseClient, ConnectionString, CosmosClient, CreateContainerOptions, PartitionKey, @@ -417,7 +418,13 @@ impl TestClient { // Initialize tracing subscriber for logging, if not already initialized. // The error is ignored because it only happens if the subscriber is already initialized. _ = tracing_subscriber::fmt() - .with_env_filter(EnvFilter::from_default_env()) + .with_env_filter( + EnvFilter::builder() + // Tests with intentional failures cause noise, so we set the default level to "off" + // to silence them unless the user explicitly configures it. + .with_default_directive("off".parse().unwrap()) + .from_env_lossy(), + ) .try_init(); let test_client = Self::from_env(options.client_application_region.clone()).await?; @@ -685,7 +692,14 @@ impl TestRunContext { const MAX_BACKOFF: Duration = Duration::from_secs(10); loop { - match container.query_items::(query.clone(), partition_key.clone(), None) { + match container + .query_items::( + query.clone(), + QueryScope::partition(partition_key.clone()), + None, + ) + .await + { Ok(pager) => match pager.try_collect::>().await { Ok(items) => return Ok(items), Err(e) if e.http_status() == Some(StatusCode::NotFound) => { @@ -900,7 +914,7 @@ impl TestRunContext { "SELECT * FROM root r WHERE r.id LIKE 'auto-test-{}'", self.run_id )); - let mut pager = self.client().query_databases(query, None)?; + let mut pager = self.client().query_databases(query, None).await?; let mut ids = Vec::new(); while let Some(db) = pager.try_next().await? { ids.push(db.id); diff --git a/sdk/cosmos/azure_data_cosmos/tests/in_memory_emulator_tests/driver_end_to_end.rs b/sdk/cosmos/azure_data_cosmos/tests/in_memory_emulator_tests/driver_end_to_end.rs index b12f7a8020f..abd9a0c2783 100644 --- a/sdk/cosmos/azure_data_cosmos/tests/in_memory_emulator_tests/driver_end_to_end.rs +++ b/sdk/cosmos/azure_data_cosmos/tests/in_memory_emulator_tests/driver_end_to_end.rs @@ -234,7 +234,7 @@ async fn create_database_and_container_through_driver() { CosmosOperation::create_container(emu_db_ref).with_body(coll_body.clone()); let emu_create_coll = backend .emulator_driver - .execute_operation(emu_create_coll_op, OperationOptions::default()) + .execute_point_operation(emu_create_coll_op, OperationOptions::default()) .await .unwrap(); @@ -244,7 +244,7 @@ async fn create_database_and_container_through_driver() { let real_db_ref = DatabaseReference::from_name(account.clone(), db_name.clone()); let real_op = CosmosOperation::create_container(real_db_ref).with_body(coll_body.clone()); let resp = driver - .execute_operation(real_op, OperationOptions::default()) + .execute_point_operation(real_op, OperationOptions::default()) .await .unwrap(); Some(resp) @@ -349,7 +349,7 @@ async fn delete_item_through_driver() { // ── Verify item is gone (emulator) ─────────────────────────── let emu_read_deleted = backend .emulator_driver - .execute_operation( + .execute_point_operation( CosmosOperation::read_item(ItemReference::from_name( &emu_container, PartitionKey::from("pk1"), @@ -366,7 +366,7 @@ async fn delete_item_through_driver() { // ── Verify item is gone (real) ─────────────────────────────── if let (Some(ref driver), Some(ref real_ctr)) = (&backend.real_driver, &real_container) { let real_read_deleted = driver - .execute_operation( + .execute_point_operation( CosmosOperation::read_item(ItemReference::from_name( real_ctr, PartitionKey::from("pk1"), @@ -498,7 +498,7 @@ async fn read_with_stale_session_token_returns_404_1002() { let real_stale_token = if let (Some(ref driver), Some(ref real_ctr)) = (&backend.real_driver, &real_container) { let seed_result = driver - .execute_operation( + .execute_point_operation( CosmosOperation::create_item(ItemReference::from_name( real_ctr, PartitionKey::from("pk1"), @@ -525,7 +525,7 @@ async fn read_with_stale_session_token_returns_404_1002() { // the emulator routed the seed write to. let emu_seed_result = backend .emulator_driver - .execute_operation( + .execute_point_operation( CosmosOperation::create_item(ItemReference::from_name( &emu_container, PartitionKey::from("pk1"), @@ -553,7 +553,7 @@ async fn read_with_stale_session_token_returns_404_1002() { // ── Emulator ───────────────────────────────────────────────── let emu_err = backend .emulator_driver - .execute_operation( + .execute_point_operation( CosmosOperation::read_item(ItemReference::from_name( &emu_container, PartitionKey::from("pk1"), @@ -587,7 +587,7 @@ async fn read_with_stale_session_token_returns_404_1002() { .clone() .expect("real_stale_token should be set when real driver is available"); let real_err = driver - .execute_operation( + .execute_point_operation( CosmosOperation::read_item(ItemReference::from_name( real_ctr, PartitionKey::from("pk1"), @@ -628,7 +628,7 @@ async fn read_after_split_refreshes_driver_routing_map() { let create = backend .emulator_driver - .execute_operation( + .execute_point_operation( CosmosOperation::create_item(ItemReference::from_name( &emu_container, PartitionKey::from("pk1"), @@ -665,7 +665,7 @@ async fn read_after_split_refreshes_driver_routing_map() { let read = backend .emulator_driver - .execute_operation( + .execute_point_operation( CosmosOperation::read_item(ItemReference::from_name( &emu_container, PartitionKey::from("pk1"), @@ -844,7 +844,7 @@ async fn paused_satellite_converges_to_latest_hub_write() { .unwrap(); driver - .execute_operation( + .execute_point_operation( CosmosOperation::create_item(ItemReference::from_name( &container, PartitionKey::from("pk1"), @@ -864,7 +864,7 @@ async fn paused_satellite_converges_to_latest_hub_write() { .unwrap(); driver - .execute_operation( + .execute_point_operation( CosmosOperation::replace_item(ItemReference::from_name( &container, PartitionKey::from("pk1"), @@ -888,7 +888,7 @@ async fn paused_satellite_converges_to_latest_hub_write() { .build(); let west_read_before_resume = driver - .execute_operation( + .execute_point_operation( CosmosOperation::read_item(ItemReference::from_name( &container, PartitionKey::from("pk1"), @@ -907,7 +907,7 @@ async fn paused_satellite_converges_to_latest_hub_write() { emulator_store.resume_replication("West US"); let west_read_after_resume = driver - .execute_operation( + .execute_point_operation( CosmosOperation::read_item(ItemReference::from_name( &container, PartitionKey::from("pk1"), @@ -989,7 +989,7 @@ async fn create_retries_after_429_throttling() { })) .unwrap(); driver - .execute_operation( + .execute_point_operation( CosmosOperation::create_item(ItemReference::from_name( &container, PartitionKey::from("pk1"), @@ -1011,7 +1011,7 @@ async fn create_retries_after_429_throttling() { let start = std::time::Instant::now(); let create = driver - .execute_operation( + .execute_point_operation( CosmosOperation::create_item(ItemReference::from_name( &container, PartitionKey::from("pk1"), @@ -1032,7 +1032,7 @@ async fn create_retries_after_429_throttling() { assert_eq!(u16::from(create.status().status_code()), 201); let read = driver - .execute_operation( + .execute_point_operation( CosmosOperation::read_item(ItemReference::from_name( &container, PartitionKey::from("pk1"), @@ -1171,7 +1171,7 @@ async fn read_failover_on_503_via_fault_injection() { .unwrap(); let emu_create = emu_driver - .execute_operation( + .execute_point_operation( CosmosOperation::create_item(ItemReference::from_name( &emu_container, PartitionKey::from("pk1"), @@ -1191,7 +1191,7 @@ async fn read_failover_on_503_via_fault_injection() { // ── Read item — should failover from East US → West US ─────── let emu_read = emu_driver - .execute_operation( + .execute_point_operation( CosmosOperation::read_item(ItemReference::from_name( &emu_container, PartitionKey::from("pk1"), @@ -1353,7 +1353,7 @@ async fn try_real_failover_comparison( db_name.clone(), ); driver - .execute_operation( + .execute_point_operation( CosmosOperation::create_database(account.clone()).with_body(db_body), OperationOptions::default(), ) @@ -1366,7 +1366,7 @@ async fn try_real_failover_comparison( })) .ok()?; driver - .execute_operation( + .execute_point_operation( CosmosOperation::create_container(db_ref.clone()).with_body(coll_body), OperationOptions::default(), ) @@ -1380,7 +1380,7 @@ async fn try_real_failover_comparison( // Create item. driver - .execute_operation( + .execute_point_operation( CosmosOperation::create_item(ItemReference::from_name( &container, PartitionKey::from("pk1"), @@ -1394,7 +1394,7 @@ async fn try_real_failover_comparison( // Read item — should failover. let read_result = driver - .execute_operation( + .execute_point_operation( CosmosOperation::read_item(ItemReference::from_name( &container, PartitionKey::from("pk1"), @@ -1406,7 +1406,7 @@ async fn try_real_failover_comparison( // Cleanup. let _ = driver - .execute_operation( + .execute_point_operation( CosmosOperation::delete_database(db_ref), OperationOptions::default(), ) @@ -1574,7 +1574,7 @@ async fn v1_writes_distribute_across_partitions() { let body_bytes = serde_json::to_vec(&body).unwrap(); let resp = backend .emulator_driver - .execute_operation( + .execute_point_operation( CosmosOperation::create_item(ItemReference::from_name( &emu_container, PartitionKey::from(pk.clone()), @@ -1601,7 +1601,7 @@ async fn v1_writes_distribute_across_partitions() { let id = format!("v1-doc-{}", i); let resp = backend .emulator_driver - .execute_operation( + .execute_point_operation( CosmosOperation::read_item(ItemReference::from_name( &emu_container, PartitionKey::from(pk), diff --git a/sdk/cosmos/azure_data_cosmos/tests/in_memory_emulator_tests/dual_backend.rs b/sdk/cosmos/azure_data_cosmos/tests/in_memory_emulator_tests/dual_backend.rs index e0b13b69c02..d8e76d36bf2 100644 --- a/sdk/cosmos/azure_data_cosmos/tests/in_memory_emulator_tests/dual_backend.rs +++ b/sdk/cosmos/azure_data_cosmos/tests/in_memory_emulator_tests/dual_backend.rs @@ -162,7 +162,7 @@ impl DualBackend { let body = serde_json::to_vec(&serde_json::json!({"id": db_name}))?; let op = CosmosOperation::create_database(account.clone()).with_body(body); let result = driver - .execute_operation(op, OperationOptions::default()) + .execute_point_operation(op, OperationOptions::default()) .await?; assert!( result.status().is_success(), @@ -212,7 +212,7 @@ impl DualBackend { }))?; let op = CosmosOperation::create_container(db_ref).with_body(body); let result = driver - .execute_operation(op, OperationOptions::default()) + .execute_point_operation(op, OperationOptions::default()) .await?; assert!( result.status().is_success(), @@ -230,7 +230,7 @@ impl DualBackend { if let (Some(driver), Some(account)) = (&self.real_driver, &self.real_account) { let db_ref = DatabaseReference::from_name(account.clone(), db_name.to_string()); let _ = driver - .execute_operation( + .execute_point_operation( CosmosOperation::delete_database(db_ref), OperationOptions::default(), ) @@ -260,14 +260,14 @@ impl DualBackend { let (emu_op, emu_opts) = build_op(emulator_container); let emu_response = self .emulator_driver - .execute_operation(emu_op, emu_opts) + .execute_point_operation(emu_op, emu_opts) .await?; // Run against real account (if available) let real_response = if let (Some(driver), Some(real_ctr)) = (&self.real_driver, real_container) { let (real_op, real_opts) = build_op(real_ctr); - let resp = driver.execute_operation(real_op, real_opts).await?; + let resp = driver.execute_point_operation(real_op, real_opts).await?; Some(resp) } else { None @@ -297,13 +297,13 @@ impl DualBackend { let (emu_op, emu_opts) = build_op(&self.emulator_account); let emu_response = self .emulator_driver - .execute_operation(emu_op, emu_opts) + .execute_point_operation(emu_op, emu_opts) .await?; let real_response = if let (Some(driver), Some(account)) = (&self.real_driver, &self.real_account) { let (real_op, real_opts) = build_op(account); - let resp = driver.execute_operation(real_op, real_opts).await?; + let resp = driver.execute_point_operation(real_op, real_opts).await?; Some(resp) } else { None diff --git a/sdk/cosmos/azure_data_cosmos_benchmarks/src/lib.rs b/sdk/cosmos/azure_data_cosmos_benchmarks/src/lib.rs index a52b728396b..037d2acf390 100644 --- a/sdk/cosmos/azure_data_cosmos_benchmarks/src/lib.rs +++ b/sdk/cosmos/azure_data_cosmos_benchmarks/src/lib.rs @@ -332,7 +332,7 @@ pub async fn setup_live() -> (Arc, ItemReference) { /// Used during setup to ignore "resource already exists" responses when /// creating the benchmark database, container, and item. fn ignore_conflict( - result: azure_core::Result, + result: azure_core::Result>, ) -> azure_core::Result<()> { match result { Ok(_) => Ok(()), diff --git a/sdk/cosmos/azure_data_cosmos_driver/docs/FEED_OPERATIONS_REQS.md b/sdk/cosmos/azure_data_cosmos_driver/docs/FEED_OPERATIONS_REQS.md new file mode 100644 index 00000000000..34f854b214c --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos_driver/docs/FEED_OPERATIONS_REQS.md @@ -0,0 +1,180 @@ +# Feed Operations — Requirements & Design Primer + +**Crate:** `azure_data_cosmos_driver` +**Scope:** Driver-internal architecture for feed operations (queries, future read-many, change feed) +**Current focus:** `SELECT * [WHERE ]` using natural order + +--- + +## 1. Context + +The driver currently handles only point operations (single request → single response). Feed operations produce multiple pages of results, may span many physical partitions, and require resumable pagination state that survives process boundaries. + +Feed operations must flow through the same execution infrastructure as point operations (region failover, session tokens, retry, diagnostics) without penalizing point-operation latency. We are doing multi-millisecond network I/O per page fetch, so the design optimizes for clarity and correctness over nanosecond-level micro-optimization. + +--- + +## 2. Dataflow Pipeline + +All operations — point and feed — are expressed as a **Dataflow Pipeline**: a tree of nodes where leaf nodes perform I/O and intermediate nodes perform sequencing or aggregation. + +### Structure + +- The pipeline is a **tree**. Nodes own their children. Fan-out creates branching. +- **Leaf nodes** issue a single Cosmos DB request via the existing operation pipeline (retry, failover, auth, transport). +- **Intermediate nodes** orchestrate their children. The first intermediate node type is `SequentialDrain`, which iterates children in EPK order, fully draining one before advancing to the next. +- **Trivial pipelines** (point operations, single-partition feeds) are a single leaf node with no intermediate parent. These must add near-zero overhead compared to today's direct execution path. + +### Pipeline Lifecycle + +- `execute_operation` is called once per page. Each call advances the pipeline by one page of results from one physical partition. +- The pipeline object itself is the in-process iteration state. The consumer of the driver (SDK layer) is responsible for holding the pipeline across calls. +- For cross-process resumption, the pipeline state serializes to a `ContinuationToken` string. On resume, the token reconstitutes a pipeline at the correct position. +- It is cleanest to unify these: `ContinuationToken` holds the live pipeline object in-process, and produces the serialized string form on demand. + +### Future Node Types (Design For, Don't Implement Yet) + +- **UnorderedMerge**: concurrent fan-out, results returned in arrival order (Read Many). +- **StreamingOrderedMerge**: k-way merge of pre-sorted partition streams (streaming ORDER BY). +- **BufferedOrderedMerge**: collect all results, then sort (non-streaming ORDER BY). +- **HybridSearch**: issues multiple distinct sub-queries (e.g., vector similarity + full-text keyword) against different child pipelines, then combines/re-ranks their results. Demonstrates that an intermediate node may have heterogeneous children with different semantics. +- **Aggregate**: client-side aggregation across partitions. + +--- + +## 3. Key Invariants + +### Ordering + +When no `ORDER BY` is specified, the driver guarantees results in **(EPK, RID) ascending order**. Within each physical partition, the server returns items in ascending RID order. Across partitions, the driver iterates in ascending EPK order. This is a driver-level guarantee for `SELECT *` queries. + +### Page Boundaries & Suspension + +- For the initial `SequentialDrain` implementation, suspension occurs at page boundaries. A continuation token for this node type only needs to track which partition is active and the server's opaque page token for that partition. +- The continuation token design must allow future node types to store intra-page progress (e.g., a streaming ORDER BY merge may suspend mid-page when its output buffer is full but source partitions are partially consumed). +- A given server continuation token only guarantees you get the *next* page of results from that partition — even if the SDK presents a per-item iterator to the user. + +### Fan-Out Limit + +Cross-partition queries are expensive by design. Containers may have hundreds of thousands of physical partitions; unbounded fan-out is dangerous from a performance and scalability perspective. + +- **Max fan-out**: The pipeline refuses to plan an operation spanning more than N physical partitions. Default: **100**. Configurable by the caller for workloads that intentionally query broadly. +- **Max concurrency**: A separate limit on concurrent in-flight requests within a single pipeline execution. Not needed for the initial `SELECT * WHERE` implementation (sequential drain uses concurrency = 1) but the limit must exist as a configuration point for future concurrent node types. + +### Partition Targeting + +An operation targets the key space in one of three mutually exclusive ways: + +1. **No partition scope** — account/database-level operations. +2. **Logical partition key** — point operations and single-partition feeds. Routes via the gateway using the PK header. No EPK headers. No fan-out. +3. **Feed range (EPK range)** — cross-partition feeds. Resolved to physical partition(s) at plan time. The full container is just the special case of `[min_epk, max_epk)`. + +These are mutually exclusive at the type level — not a runtime check. + +--- + +## 4. Continuation Token & Resumption + +### Dual Nature + +The `ContinuationToken` type serves two roles: + +1. **In-process**: holds the live pipeline state. The SDK keeps it across `execute_operation` calls. No serialization needed per page. +2. **Cross-process**: serializes to an opaque string (base64url-encoded JSON). Safe to store in databases, send to browsers, carry across SDK upgrades. + +### Token Properties + +- **Durable across SDK versions.** Newer SDKs must deserialize tokens from older SDKs. Version field is the option of last resort. +- **O(1) size for sequential drain.** Only the active partition's EPK bounds and server continuation are stored. Drained partitions are reconstructed from the EPK cursor on resume. +- **Bound to the operation.** Tokens include a container RID and operation kind. Replaying a token against a different container or operation type is rejected. +- **Survives partition topology changes.** Tokens store EPK bounds, not physical partition IDs. Splits and merges are handled by re-resolving EPK bounds to current partitions. + +### What the Token Does NOT Encode + +- Query text or parameters (caller must supply an equivalent operation). +- Session tokens or consistency state. +- Per-partition state for all partitions (only the cursor position for sequential drain). + +--- + +## 5. Pipeline Repair (Splits) + +Physical partitions can split at any time. The pipeline must handle this transparently. + +### Leaf Node Invariant + +At all times, a leaf node targets **one specific physical partition** and **one EPK range** that is contained within that partition and does not overlap with any of its peer leaf nodes. A leaf node can only issue one request, so it is impossible for it to target multiple physical partitions. + +### Splits Break the Invariant + +When a physical partition splits, a leaf node's EPK range suddenly covers two or more new physical partitions. The pipeline detects this via a 410 (PartitionIsGone) response — either a full page is returned successfully or a 410 is returned; this never occurs mid-page. + +The leaf node is responsible for **splitting itself** to restore the invariant: + +1. Invalidate the cached partition map for the container. +2. Re-resolve the leaf's EPK range to the new physical partition(s). +3. The single leaf becomes multiple leaves in the parent's children list (the parent must obviously cooperate with this), each targeting one of the new physical partitions with a non-overlapping sub-range of the original EPK range. Depth of the tree remains the same. +4. Execution resumes against the correct new leaf. + +### Merges Do Not Require Repair + +After a merge, multiple leaf nodes may point to different EPK ranges on the same physical partition. This is acceptable — the leaf still targets a single partition and uses EPK min/max headers to scope its request to its intended slice. No pipeline restructuring is needed. (Consolidating redundant leaves after a merge is a potential future optimization but is out of scope to avoid complicating the design.) + +--- + +## 6. Current Implementation Focus + +The initial implementation targets `SELECT * [WHERE ]` queries: + +- **Single-partition**: trivial pipeline (one leaf node). The server evaluates the full SQL including any WHERE clause. Paginated via server continuations. +- **Cross-partition**: `SequentialDrain` intermediate node over N leaf nodes (one per physical partition). Drains partitions in EPK order. No query plan fetch required for passthrough SELECT/WHERE. + +### What This Exercises + +- Partition key range resolution and caching. +- Sequential traversal across partitions in EPK order. +- EPK range scoping via request headers. +- Paginated reads within each partition. +- Continuation token serialization, resume, and topology-change survival. +- Integration with the existing operation pipeline for each sub-request. +- Pipeline repair on partition splits/merges. + +--- + +## 7. Design Boundaries + +### The Driver Does NOT: + +- Deserialize item bodies. It returns raw bytes per item; the SDK handles deserialization. +- Create telemetry spans. It returns structured diagnostics data; the SDK creates OpenTelemetry spans. +- Own the iteration lifetime for multi-page feeds. It executes one page per call; the SDK loops. +- Fetch or interpret backend query plans (for the current SELECT/WHERE scope). + +### Item Body Opacity + +For `SequentialDrain`, item bodies are fully opaque binary payloads. The pipeline does not inspect them — ordering is already established by the backend. + +Future node types (e.g., streaming ORDER BY, hybrid search) may require partial parsing of item bodies. The backend query plan can rewrite the query to use a standardized envelope (promoting ordering keys to top-level fields and demoting the raw user document to a `payload` field). This varied-shape pattern must be considered in the overall design direction, but does not need to be accommodated in the current implementation. + +### The Driver DOES: + +- Plan the pipeline (determine targeting, resolve partitions, build the node tree). +- Execute one page per call through the existing retry/failover infrastructure. +- Produce and consume continuation tokens. +- Repair the pipeline on topology changes (splits/merges). +- Enforce fan-out limits. +- Collect per-node diagnostics for the SDK to surface. + +--- + +## 8. Future Considerations (Inform Design, Don't Implement) + +These capabilities must be achievable without redesigning the pipeline model: + +- **Streaming ORDER BY**: k-way merge of partition streams. Requires fetching a backend query plan to determine sort keys. New intermediate node type. +- **Buffered ORDER BY**: collect all partition results, sort client-side. Same query plan requirement. Different intermediate node. +- **Vector / Hybrid Search**: may require preliminary requests to fetch full-text statistics before issuing the main query. Multi-phase pipeline execution. +- **Read Many Items**: fan-out by (ID, PK) pairs grouped by partition. Concurrent leaf execution with an unordered merge intermediate node. +- **Change Feed**: per-range continuation tokens (O(N) token size, unlike sequential drain's O(1)). Different resumption semantics. + +The pipeline's tree structure, typed node hierarchy, and separation of planning from execution accommodate all of these as new node types and planning strategies without changing the core execution loop. diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/cosmos_driver.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/cosmos_driver.rs index be27bca493a..0e79680df75 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/driver/cosmos_driver.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/cosmos_driver.rs @@ -7,20 +7,31 @@ use crate::{ diagnostics::{ DiagnosticsContextBuilder, PipelineType, TransportHttpVersion, TransportSecurity, }, - driver::routing::{ - partition_endpoint_state::PartitionFailoverConfig, - partition_key_range_id::PartitionKeyRangeId, session_manager::SessionManager, - CosmosEndpoint, LocationStateStore, + driver::{ + cache::{PartitionKeyRangeCache, PkRangeFetchResult}, + dataflow::{ + planner, query_plan::QueryPlan, CachedTopologyProvider, OperationPlan, + PartitionRoutingRefresh, PipelineContext, PipelineNodeState, RequestExecutor, + RequestTarget, TopologyProvider, + }, + pipeline::operation_pipeline::OperationOverrides, + routing::{ + partition_endpoint_state::PartitionFailoverConfig, + partition_key_range_id::PartitionKeyRangeId, session_manager::SessionManager, + CosmosEndpoint, LocationStateStore, + }, }, models::{ effective_partition_key::EffectivePartitionKey, AccountEndpoint, AccountReference, - ActivityId, ContainerProperties, ContainerReference, CosmosOperation, DatabaseProperties, - DatabaseReference, PartitionKey, ResourceType, + ActivityId, ContainerProperties, ContainerReference, ContinuationToken, CosmosOperation, + DatabaseProperties, DatabaseReference, OperationTarget, PartitionKey, ResolvedToken, + ResourceType, }, options::{ ConnectionPoolOptions, DiagnosticsOptions, DriverOptions, OperationOptions, OperationOptionsView, ThroughputControlGroupSnapshot, }, + CosmosResponse, }; use arc_swap::ArcSwap; use futures::future::BoxFuture; @@ -31,7 +42,7 @@ use std::time::Duration; use url::Url; use super::{ - cache::{parse_pk_ranges_response, AccountRegion, PartitionKeyRangeCache, PkRangeFetchResult}, + cache::{parse_pk_ranges_response, AccountRegion}, transport::{ cosmos_headers, cosmos_transport_client::HttpRequest, is_emulator_host, request_signing, uses_dataplane_pipeline, AuthorizationContext, CosmosTransport, @@ -39,6 +50,64 @@ use super::{ CosmosDriverRuntime, }; +struct DriverRequestExecutor<'a> { + driver: &'a CosmosDriver, + options: &'a OperationOptions, +} + +fn request_target_overrides( + target: RequestTarget, + continuation: Option, +) -> OperationOverrides { + match target { + RequestTarget::LogicalPartitionKey(pk) => OperationOverrides { + partition_key: Some(pk), + continuation, + ..Default::default() + }, + RequestTarget::PartitionKeyRange { + partition_key_range_id, + .. + } => OperationOverrides { + partition_key_range_id: Some(partition_key_range_id), + continuation, + ..Default::default() + }, + RequestTarget::EffectivePartitionKeyRange { + range, + partition_key_range_id, + } => OperationOverrides { + partition_key_range_id: Some(partition_key_range_id), + feed_range: Some(range), + continuation, + ..Default::default() + }, + RequestTarget::NonPartitioned => OperationOverrides { + continuation, + ..Default::default() + }, + } +} + +impl RequestExecutor for DriverRequestExecutor<'_> { + fn execute_request<'a>( + &'a mut self, + operation: &'a CosmosOperation, + target: RequestTarget, + _partition_routing_refresh: PartitionRoutingRefresh, + continuation: Option, + ) -> BoxFuture<'a, azure_core::Result> { + let driver = self.driver; + let overrides = request_target_overrides(target, continuation); + + Box::pin(async move { + driver + .execute_operation_direct(operation, overrides, self.options) + .await + }) + } +} + /// Cosmos DB driver instance. /// /// A driver represents a connection to a specific Cosmos DB account. It is created @@ -615,7 +684,7 @@ impl CosmosDriver { let options = OperationOptions::default(); let db_result = self - .execute_operation( + .execute_point_operation( CosmosOperation::read_database(db_ref.clone()), options.clone(), ) @@ -630,7 +699,7 @@ impl CosmosDriver { })?; let container_result = self - .execute_operation( + .execute_point_operation( CosmosOperation::read_container_by_name(db_ref, container_name.to_owned()), options, ) @@ -667,7 +736,7 @@ impl CosmosDriver { let options = OperationOptions::default(); let db_result = self - .execute_operation( + .execute_point_operation( CosmosOperation::read_database(db_ref.clone()), options.clone(), ) @@ -681,7 +750,7 @@ impl CosmosDriver { .unwrap_or_else(|| db_rid.to_owned()); let container_result = self - .execute_operation( + .execute_point_operation( CosmosOperation::read_container_by_rid(db_ref, container_rid.to_owned()), options, ) @@ -984,7 +1053,10 @@ impl CosmosDriver { ); let options = OperationOptions::default().with_custom_headers(custom_headers); - match self.execute_operation(operation, options).await { + match self + .execute_operation_direct(&operation, OperationOverrides::default(), &options) + .await + { Ok(response) => { let etag = response.headers().etag.as_ref().map(|e| e.to_string()); @@ -1083,7 +1155,10 @@ impl CosmosDriver { // Need both a container reference and a partition key. let container = operation.container()?; - let partition_key = operation.partition_key()?; + let partition_key = match operation.target() { + OperationTarget::PartitionKey(ref pk) => pk, + _ => return None, + }; self.pk_range_cache .resolve_partition_key_range_id(container, partition_key, false, |c, cont| { @@ -1095,23 +1170,32 @@ impl CosmosDriver { /// Executes a Cosmos DB operation. /// - /// This method computes effective options by merging the provided operation options - /// with driver and runtime defaults, then executes the operation. + /// This method executes an operation by planning it first and then immediately + /// executing one page. This is sufficient for operations with trivial plans, + /// such as point operations and single-partition queries. + /// However, if planning is complicated and multiple pages are going to be requested, + /// in that case, the caller should use the [`plan_operation`](Self::plan_operation) + /// method to build a [`OperationPlan`] and then call [`execute_plan`](Self::execute_plan) + /// for each page of the plan. + /// Retaining the [`OperationPlan`] allows the caller to resume execution from a + /// previous page, maintaining all state, and avoiding unnecessary replanning + /// and continuation token management. /// /// # Parameters /// - /// - `operation`: The operation to execute - /// - `options`: Operation-specific options that override driver and runtime defaults + /// - `operation`: The operation to execute. + /// - `options`: Operation-specific options that override driver and runtime defaults. /// /// # Returns /// - /// Returns a [`crate::models::CosmosResponse`] on success. + /// Returns `Ok(Some(response))` when a page of results is produced, or + /// `Ok(None)` when the pipeline is fully drained (no more pages). /// /// # Errors /// /// Returns an error if: - /// - The account has no authentication configured - /// - The resource reference cannot produce a valid path + /// - The driver has not been initialized + /// - Planning fails (e.g. invalid operation target, backend query plan error) /// - The HTTP request fails /// /// # Example @@ -1132,12 +1216,12 @@ impl CosmosDriver { /// /// let driver = runtime.get_or_create_driver(account, None).await?; /// - /// // Execute operations with operation-specific options that override defaults + /// // Point operation: plan and execute in one call. /// let options = OperationOptionsBuilder::new() /// .with_content_response_on_write(ContentResponseOnWrite::Disabled) /// .build(); /// - /// // let result = driver.execute_operation(operation, options).await?; + /// // let result = driver.execute_operation(operation, options, None).await?; /// # Ok(()) /// # } /// ``` @@ -1145,7 +1229,53 @@ impl CosmosDriver { &self, operation: CosmosOperation, options: OperationOptions, + ) -> azure_core::Result> { + // TODO: This boxing is a temporary fix to avoid a large future. + // We need to do some refactoring here to shrink the future size and avoid this heap allocation if possible. + Box::pin(async { + let container = operation.container().cloned(); + let mut plan = self.plan_operation(operation, &options, None).await?; + self.execute_plan(&mut plan, container, options).await + }) + .await + } + + /// Executes a point operation (read/write item, read database, etc.) without a pre-planned pipeline. + /// + /// This is a convenience method around [`execute_operation`](CosmosDriver::execute_operation) that asserts at debug-time that the operation + /// does not return an empty page. + pub async fn execute_point_operation( + &self, + operation: CosmosOperation, + options: OperationOptions, ) -> azure_core::Result { + match self.execute_operation(operation, options).await { + Ok(Some(r)) => Ok(r), + Ok(None) => { + if cfg!(debug_assertions) { + panic!("point operation returned an empty page") + } + Err(azure_core::Error::with_message( + azure_core::error::ErrorKind::Other, + "internal error: point operation returned an empty page", + )) + } + Err(e) => Err(e), + } + } + + /// Executes a single page of a pre-planned operation using the given plan and options. + /// + /// This function mutates the plan in place to account for any changes that occur during execution + /// (e.g. topology repairs, advancing page state, etc.). + /// After this returns, the plan may be executed again to fetch the next page of results, if any. + /// Once this returns `None`, there are no more pages to fetch, and the operation is complete. + pub async fn execute_plan( + &self, + plan: &mut OperationPlan, + container: Option, + options: OperationOptions, + ) -> azure_core::Result> { if !self.initialized.load(Ordering::Acquire) { let endpoint = AccountEndpoint::from(self.options.account()); return Err(azure_core::Error::with_message( @@ -1156,10 +1286,43 @@ impl CosmosDriver { ), )); } - tracing::debug!("operation started"); + tracing::debug!("plan execution started"); + + let mut executor = DriverRequestExecutor { + driver: self, + options: &options, + }; + + let mut topology = container.map(|c| { + CachedTopologyProvider::new(&self.pk_range_cache, c, |container, continuation| { + self.fetch_pk_ranges_from_service(container, continuation) + }) + }); + + let mut context = PipelineContext::new( + &mut executor, + topology.as_mut().map(|t| t as &mut dyn TopologyProvider), + ); + + plan.pipeline.next_page(&mut context).await + } + + async fn execute_operation_direct( + &self, + operation: &CosmosOperation, + overrides: OperationOverrides, + options: &OperationOptions, + ) -> azure_core::Result { + tracing::debug!( + operation_type = ?operation.operation_type(), + resource_type = ?operation.resource_type(), + resource_reference = ?operation.resource_reference(), + overrides = ?overrides, + body_length = operation.body().map(|b| b.len()), + "executing operation"); // Step 1: Build the single OperationOptionsView for layered resolution. - let effective_options = self.operation_options_view(&options); + let effective_options = self.operation_options_view(options); // Step 2: Resolve effective throughput control group (if any). let effective_control_group = match operation.container() { @@ -1199,7 +1362,7 @@ impl CosmosDriver { // When partition-level failover is enabled, resolving the range ID // before the first attempt lets the pipeline apply partition overrides // from the very first request instead of only after the first retry. - let pre_resolved_pk_range_id = self.pre_resolve_partition_key_range_id(&operation).await; + let pre_resolved_pk_range_id = self.pre_resolve_partition_key_range_id(operation).await; // Step 6: Select the adaptive transport context for the chosen pipeline let transport = self.transport(); @@ -1240,7 +1403,8 @@ impl CosmosDriver { // Step 8: Execute via the new operation pipeline super::pipeline::operation_pipeline::execute_operation_pipeline( - &operation, + operation, + overrides, &effective_options, options.custom_headers(), self.location_state_store.as_ref(), @@ -1297,7 +1461,7 @@ impl CosmosDriver { /// // Use the resolved container for item operations /// let item = ItemReference::from_name(&container, PartitionKey::from("pk1"), "doc1"); /// let result = driver - /// .execute_operation(CosmosOperation::read_item(item), OperationOptions::default()) + /// .execute_point_operation(CosmosOperation::read_item(item), OperationOptions::default()) /// .await?; /// # Ok(()) /// # } @@ -1361,6 +1525,113 @@ impl CosmosDriver { Ok(resolved.as_ref().clone()) } + /// Plans the execution of a Cosmos DB operation. + /// + /// For trivial operations (non-query or single-partition), returns a + /// singleton pipeline immediately. For cross-partition queries, fetches a + /// query plan from the backend and builds a fan-out pipeline. + /// + /// `continuation` optionally provides resume state from a prior call. Two + /// kinds of tokens are accepted: + /// + /// - SDK-issued tokens (`c1.…`) carry a serialized snapshot of the + /// previous pipeline's state and can resume any operation. + /// - Opaque server-issued tokens (no `c.` prefix) are accepted only + /// for trivial operations; passing one to a cross-partition query + /// returns a [`DataConversion`](azure_core::error::ErrorKind::DataConversion) + /// error. + pub async fn plan_operation( + &self, + operation: CosmosOperation, + options: &OperationOptions, + continuation: Option<&ContinuationToken>, + ) -> azure_core::Result { + if !self.initialized.load(Ordering::Acquire) { + let endpoint = AccountEndpoint::from(self.options.account()); + return Err(azure_core::Error::with_message( + azure_core::error::ErrorKind::Other, + format!( + "CosmosDriver for {endpoint} has not been initialized; call initialize() or \ + use CosmosDriverRuntime::get_or_create_driver() which initializes automatically" + ), + )); + } + + tracing::debug!(operation_type = ?operation.operation_type(), resource_type = ?operation.resource_type(), resource_reference = ?operation.resource_reference(), "planning operation"); + + // Share the operation across every Request node in the resulting plan. + // Per-Request differences are layered on at execution time via + // OperationOverrides; the operation itself is never mutated. + let operation = Arc::new(operation); + + // Resolve the continuation token (if any) into a planner-ready resume + // state. Server-issued tokens are only valid for trivial operations. + let resume_state = match continuation { + None => None, + Some(token) => match token.resolve()? { + ResolvedToken::ClientV1(state) => Some(state), + ResolvedToken::ServerOpaque(server_token) => { + if !operation.is_trivial() { + return Err(azure_core::Error::with_message( + azure_core::error::ErrorKind::DataConversion, + "an opaque server continuation token cannot be used to resume a \ + cross-partition query; use the SDK-issued continuation token from \ + FeedPageIterator::to_continuation_token()", + )); + } + Some(PipelineNodeState::Request { + server_continuation: Some(server_token), + }) + } + }, + }; + + // Trivial plan: anything that isn't a cross-partition query. + if operation.is_trivial() { + let pipeline = planner::build_trivial_pipeline(operation, resume_state)?; + return Ok(OperationPlan::new(pipeline)); + } + + // Cross-partition query: fetch query plan from backend. + let container = operation.container().ok_or_else(|| { + azure_core::Error::with_message( + azure_core::error::ErrorKind::Other, + "cross-partition query requires a container reference", + ) + })?; + + let query_plan_operation = CosmosOperation::query_plan(container.clone()) + .with_body(operation.body().unwrap_or_default().to_vec()); + + let response = self + .execute_operation_direct( + &query_plan_operation, + OperationOverrides::default(), + options, + ) + .await?; + + let query_plan: QueryPlan = serde_json::from_slice(response.body()).map_err(|e| { + azure_core::Error::with_message( + azure_core::error::ErrorKind::DataConversion, + format!("failed to parse query plan response: {e}"), + ) + })?; + + // Build the fan-out pipeline using the query plan. + let container_ref = container.clone(); + let mut topology = CachedTopologyProvider::new( + &self.pk_range_cache, + container_ref, + |container, continuation| self.fetch_pk_ranges_from_service(container, continuation), + ); + + let pipeline = + planner::build_sequential_drain(&query_plan, &mut topology, &operation, resume_state) + .await?; + Ok(OperationPlan::new(pipeline)) + } + /// Returns all partition key ranges for a container, ordered by min EPK. /// /// Uses the driver's internal `PartitionKeyRangeCache`. When `force_refresh` @@ -2178,15 +2449,17 @@ mod tests { ); } - /// Compile-time assertion that the `execute_operation` future is `Send`. + /// Compile-time assertion that functions are send. /// /// This function is never called; it only needs to compile. - /// If the future returned by `execute_operation` is not `Send`, compilation will fail. #[allow(dead_code, unreachable_code, unused_variables)] - fn _assert_execute_operation_future_is_send() { + fn _assert_functions_are_send() { fn assert_send(_: T) {} let driver: &CosmosDriver = todo!(); assert_send(driver.execute_operation(todo!(), todo!())); + assert_send(driver.execute_point_operation(todo!(), todo!())); + assert_send(driver.execute_plan(todo!(), todo!(), todo!())); + assert_send(driver.plan_operation(todo!(), todo!(), todo!())); } // Account properties with two readable locations for regional fallback tests. @@ -2216,6 +2489,43 @@ mod tests { Arc::new(serde_json::from_str(MULTI_REGION_ACCOUNT_PROPERTIES).unwrap()) } + #[test] + fn partition_key_range_override_does_not_set_feed_range() { + let overrides = request_target_overrides( + RequestTarget::PartitionKeyRange { + range: crate::models::FeedRange::new( + EffectivePartitionKey::from("10"), + EffectivePartitionKey::from("20"), + ), + partition_key_range_id: "7".to_string(), + }, + Some("ct".to_string()), + ); + + assert_eq!(overrides.partition_key_range_id.as_deref(), Some("7")); + assert_eq!(overrides.continuation.as_deref(), Some("ct")); + assert_eq!(overrides.feed_range, None); + } + + #[test] + fn effective_partition_key_range_override_sets_feed_range() { + let range = crate::models::FeedRange::new( + EffectivePartitionKey::from("10"), + EffectivePartitionKey::from("20"), + ); + let overrides = request_target_overrides( + RequestTarget::EffectivePartitionKeyRange { + range: range.clone(), + partition_key_range_id: "merged".to_string(), + }, + Some("ct".to_string()), + ); + + assert_eq!(overrides.partition_key_range_id.as_deref(), Some("merged")); + assert_eq!(overrides.continuation.as_deref(), Some("ct")); + assert_eq!(overrides.feed_range, Some(range)); + } + #[tokio::test] async fn refresh_falls_back_to_regional_endpoints_when_primary_fails() { // Primary metadata request fails (connection error), then the diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/context.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/context.rs new file mode 100644 index 00000000000..2018fef71a3 --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/context.rs @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//! Execution context plumbed through [`PipelineNode::next_page`] calls. + +use futures::future::BoxFuture; + +use crate::models::{CosmosOperation, CosmosResponse, FeedRange}; + +use super::request::RequestTarget; + +/// Request execution mode for partition routing metadata. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub(crate) enum PartitionRoutingRefresh { + /// Use existing partition routing metadata. + UseCached, + /// Force partition routing metadata to be refreshed before executing. + ForceRefresh, +} + +/// Executes leaf request nodes through the existing operation pipeline. +pub(crate) trait RequestExecutor: Send { + /// Executes a single request node. + fn execute_request<'a>( + &'a mut self, + operation: &'a CosmosOperation, + target: RequestTarget, + partition_routing_refresh: PartitionRoutingRefresh, + continuation: Option, + ) -> BoxFuture<'a, azure_core::Result>; +} + +/// Resolves EPK ranges to their current physical partition key ranges. +/// +/// Used by pipeline nodes to recover from partition topology changes (splits) +/// and by the planner to resolve initial query ranges. +/// The `PartitionKeyRangeCache` implements this trait in production. +pub(crate) trait TopologyProvider: Send { + /// Resolves the physical partitions that currently cover the given EPK range. + /// + /// `refresh` controls whether the topology cache is refreshed before resolving: + /// callers use [`PartitionRoutingRefresh::ForceRefresh`] for split recovery + /// and [`PartitionRoutingRefresh::UseCached`] for planning. + /// + /// Returns partition key range IDs paired with their EPK sub-ranges, ordered + /// by EPK from smallest to largest. + fn resolve_ranges<'a>( + &'a mut self, + range: &'a FeedRange, + refresh: PartitionRoutingRefresh, + ) -> BoxFuture<'a, azure_core::Result>>; +} + +/// A physical partition's EPK sub-range, as resolved from the current topology. +#[derive(Debug, Clone)] +pub(crate) struct ResolvedRange { + /// The partition key range ID for this physical partition. + pub partition_key_range_id: String, + /// The EPK sub-range within this physical partition. + pub range: FeedRange, +} + +/// Context passed through dataflow node execution. +pub(crate) struct PipelineContext<'a> { + request_executor: &'a mut dyn RequestExecutor, + topology_provider: Option<&'a mut dyn TopologyProvider>, +} + +impl<'a> PipelineContext<'a> { + /// Creates a new pipeline execution context. + /// + /// `topology_provider` is `None` for plans that cannot need topology + /// resolution (e.g. non-partitioned resource operations). If a node calls + /// [`resolve_ranges`](Self::resolve_ranges) while it is `None`, an error + /// is returned. + pub(crate) fn new( + request_executor: &'a mut dyn RequestExecutor, + topology_provider: Option<&'a mut dyn TopologyProvider>, + ) -> Self { + Self { + request_executor, + topology_provider, + } + } + + pub(crate) async fn execute_request( + &mut self, + operation: &CosmosOperation, + target: RequestTarget, + partition_routing_refresh: PartitionRoutingRefresh, + continuation: Option, + ) -> azure_core::Result { + self.request_executor + .execute_request(operation, target, partition_routing_refresh, continuation) + .await + } + + pub(crate) async fn resolve_ranges( + &mut self, + range: &FeedRange, + refresh: PartitionRoutingRefresh, + ) -> azure_core::Result> { + let provider = self.topology_provider.as_deref_mut().ok_or_else(|| { + azure_core::Error::with_message( + azure_core::error::ErrorKind::Other, + "topology resolution requested for a plan that was not given a topology provider", + ) + })?; + provider.resolve_ranges(range, refresh).await + } +} diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/drain.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/drain.rs new file mode 100644 index 00000000000..60c2af982e2 --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/drain.rs @@ -0,0 +1,686 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//! Sequential drain node for cross-partition feed operations. +//! +//! `SequentialDrain` iterates its children in EPK order (left to right), +//! fully draining one child before advancing to the next. When a child +//! signals a partition split via [`PageResult::SplitRequired`], the drain +//! splices replacement nodes into its children list and retries. + +use std::collections::VecDeque; + +use async_trait::async_trait; + +use crate::models::FeedRange; + +use super::{PageResult, PipelineContext, PipelineNode, PipelineNodeState}; + +/// Maximum number of consecutive split retries before giving up. +/// +/// In practice a split produces 2–3 new ranges. This limit prevents infinite +/// loops if the topology provider keeps returning splits. +const MAX_SPLIT_RETRIES: usize = 10; + +/// Drains child nodes sequentially in EPK order. +/// +/// Each call to `next_page` returns the next page from the left-most (lowest EPK) +/// child. When that child is drained, it is removed and the next child becomes active. +/// When all children are drained, the node itself is drained. +pub(crate) struct SequentialDrain { + children: VecDeque>, +} + +impl SequentialDrain { + /// Creates a new sequential drain over the given children. + /// + /// Children must be ordered by EPK range from smallest to largest. + pub(crate) fn new(children: Vec>) -> Self { + Self { + children: children.into(), + } + } +} + +#[async_trait] +impl PipelineNode for SequentialDrain { + async fn next_page( + &mut self, + context: &mut PipelineContext<'_>, + ) -> azure_core::Result { + let mut split_retries = 0; + + loop { + let Some(current) = self.children.front_mut() else { + return Ok(PageResult::Drained); + }; + + match current.next_page(context).await? { + PageResult::Page { + response, + is_terminal, + } => { + if is_terminal { + // The front child has emitted its last page; evict it + // now so a snapshot taken after this call no longer + // references it. The drain itself is terminal only + // when this was its last child. + self.children.pop_front(); + return Ok(PageResult::Page { + response, + is_terminal: self.children.is_empty(), + }); + } + return Ok(PageResult::Page { + response, + is_terminal: false, + }); + } + PageResult::Drained => { + self.children.pop_front(); + // Loop to try the next child. + } + PageResult::SplitRequired { replacement_nodes } => { + split_retries += 1; + if split_retries > MAX_SPLIT_RETRIES { + // This should be ridiculously rare. + // The topology provider already waits for splits to converge before returning. + return Err(azure_core::Error::with_message( + azure_core::error::ErrorKind::Other, + format!( + "exceeded maximum split retries ({MAX_SPLIT_RETRIES}) \ + in SequentialDrain" + ), + )); + } + + // Remove the split child and splice in replacements at the front. + self.children.pop_front(); + for (i, node) in replacement_nodes.into_iter().enumerate() { + self.children.insert(i, node); + } + // Loop to drain the first replacement. + } + } + } + } + + #[cfg(test)] + fn into_children(self) -> Vec> { + self.children.into_iter().collect() + } + + fn snapshot_state(&self) -> PipelineNodeState { + let Some(front) = self.children.front() else { + return PipelineNodeState::Drained; + }; + let Some(range) = front.feed_range() else { + // Shouldn't happen for an EPK-ordered drain, but degrade gracefully: + // serialize the child snapshot directly with no cursor. + return front.snapshot_state(); + }; + PipelineNodeState::SequentialDrain { + current_min_epk: range.min_inclusive().as_str().to_string(), + left_most: Box::new(front.snapshot_state()), + } + } + + fn feed_range(&self) -> Option<&FeedRange> { + self.children.front().and_then(|c| c.feed_range()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::driver::dataflow::mocks::*; + use crate::models::effective_partition_key::EffectivePartitionKey; + + #[tokio::test] + async fn drains_single_child() { + let child = MockLeaf::with_pages(vec![ + Ok(PageResult::Page { + response: response(b"p1"), + is_terminal: false, + }), + Ok(PageResult::Page { + response: response(b"p2"), + is_terminal: false, + }), + Ok(PageResult::Drained), + ]); + let mut drain = SequentialDrain::new(vec![Box::new(child)]); + let mut executor = NoopRequestExecutor; + let mut topology = NoopTopologyProvider; + let mut context = PipelineContext::new(&mut executor, Some(&mut topology)); + + assert_eq!( + unwrap_page(drain.next_page(&mut context).await).body(), + b"p1" + ); + assert_eq!( + unwrap_page(drain.next_page(&mut context).await).body(), + b"p2" + ); + assert_drained(drain.next_page(&mut context).await); + } + + #[tokio::test] + async fn drains_multiple_children_in_order() { + let child1 = MockLeaf::with_pages(vec![ + Ok(PageResult::Page { + response: response(b"c1-p1"), + is_terminal: false, + }), + Ok(PageResult::Drained), + ]); + let child2 = MockLeaf::with_pages(vec![ + Ok(PageResult::Page { + response: response(b"c2-p1"), + is_terminal: false, + }), + Ok(PageResult::Page { + response: response(b"c2-p2"), + is_terminal: false, + }), + Ok(PageResult::Drained), + ]); + let child3 = MockLeaf::with_pages(vec![ + Ok(PageResult::Page { + response: response(b"c3-p1"), + is_terminal: false, + }), + Ok(PageResult::Drained), + ]); + let mut drain = + SequentialDrain::new(vec![Box::new(child1), Box::new(child2), Box::new(child3)]); + let mut executor = NoopRequestExecutor; + let mut topology = NoopTopologyProvider; + let mut context = PipelineContext::new(&mut executor, Some(&mut topology)); + + assert_eq!( + unwrap_page(drain.next_page(&mut context).await).body(), + b"c1-p1" + ); + assert_eq!( + unwrap_page(drain.next_page(&mut context).await).body(), + b"c2-p1" + ); + assert_eq!( + unwrap_page(drain.next_page(&mut context).await).body(), + b"c2-p2" + ); + assert_eq!( + unwrap_page(drain.next_page(&mut context).await).body(), + b"c3-p1" + ); + assert_drained(drain.next_page(&mut context).await); + } + + #[tokio::test] + async fn empty_drain_is_immediately_drained() { + let mut drain = SequentialDrain::new(vec![]); + let mut executor = NoopRequestExecutor; + let mut topology = NoopTopologyProvider; + let mut context = PipelineContext::new(&mut executor, Some(&mut topology)); + + assert_drained(drain.next_page(&mut context).await); + } + + #[tokio::test] + async fn propagates_child_error() { + let child = MockLeaf::with_pages(vec![Err(azure_core::Error::with_message( + azure_core::error::ErrorKind::Other, + "test error", + ))]); + let mut drain = SequentialDrain::new(vec![Box::new(child)]); + let mut executor = NoopRequestExecutor; + let mut topology = NoopTopologyProvider; + let mut context = PipelineContext::new(&mut executor, Some(&mut topology)); + + let err = drain.next_page(&mut context).await.unwrap_err(); + assert_eq!(err.to_string(), "test error"); + } + + #[tokio::test] + async fn handles_split_of_first_child() { + let replacement1 = MockLeaf::with_pages(vec![ + Ok(PageResult::Page { + response: response(b"split-left"), + is_terminal: false, + }), + Ok(PageResult::Drained), + ]); + let replacement2 = MockLeaf::with_pages(vec![ + Ok(PageResult::Page { + response: response(b"split-right"), + is_terminal: false, + }), + Ok(PageResult::Drained), + ]); + + let split_child = MockLeaf::with_pages(vec![Ok(PageResult::SplitRequired { + replacement_nodes: vec![Box::new(replacement1), Box::new(replacement2)], + })]); + + let trailing_child = MockLeaf::with_pages(vec![ + Ok(PageResult::Page { + response: response(b"trailing"), + is_terminal: false, + }), + Ok(PageResult::Drained), + ]); + + let mut drain = SequentialDrain::new(vec![Box::new(split_child), Box::new(trailing_child)]); + let mut executor = NoopRequestExecutor; + let mut topology = NoopTopologyProvider; + let mut context = PipelineContext::new(&mut executor, Some(&mut topology)); + + assert_eq!( + unwrap_page(drain.next_page(&mut context).await).body(), + b"split-left" + ); + assert_eq!( + unwrap_page(drain.next_page(&mut context).await).body(), + b"split-right" + ); + assert_eq!( + unwrap_page(drain.next_page(&mut context).await).body(), + b"trailing" + ); + assert_drained(drain.next_page(&mut context).await); + } + + #[tokio::test] + async fn handles_split_of_middle_child() { + let child1 = MockLeaf::with_pages(vec![ + Ok(PageResult::Page { + response: response(b"c1"), + is_terminal: false, + }), + Ok(PageResult::Drained), + ]); + + let replacement = MockLeaf::with_pages(vec![ + Ok(PageResult::Page { + response: response(b"c2-split"), + is_terminal: false, + }), + Ok(PageResult::Drained), + ]); + let split_child = MockLeaf::with_pages(vec![Ok(PageResult::SplitRequired { + replacement_nodes: vec![Box::new(replacement)], + })]); + + let child3 = MockLeaf::with_pages(vec![ + Ok(PageResult::Page { + response: response(b"c3"), + is_terminal: false, + }), + Ok(PageResult::Drained), + ]); + + let mut drain = SequentialDrain::new(vec![ + Box::new(child1), + Box::new(split_child), + Box::new(child3), + ]); + let mut executor = NoopRequestExecutor; + let mut topology = NoopTopologyProvider; + let mut context = PipelineContext::new(&mut executor, Some(&mut topology)); + + assert_eq!( + unwrap_page(drain.next_page(&mut context).await).body(), + b"c1" + ); + assert_eq!( + unwrap_page(drain.next_page(&mut context).await).body(), + b"c2-split" + ); + assert_eq!( + unwrap_page(drain.next_page(&mut context).await).body(), + b"c3" + ); + assert_drained(drain.next_page(&mut context).await); + } + + #[tokio::test] + async fn handles_split_of_last_child() { + let child1 = MockLeaf::with_pages(vec![ + Ok(PageResult::Page { + response: response(b"c1"), + is_terminal: false, + }), + Ok(PageResult::Drained), + ]); + + let replacement = MockLeaf::with_pages(vec![ + Ok(PageResult::Page { + response: response(b"last-split"), + is_terminal: false, + }), + Ok(PageResult::Drained), + ]); + let split_child = MockLeaf::with_pages(vec![Ok(PageResult::SplitRequired { + replacement_nodes: vec![Box::new(replacement)], + })]); + + let mut drain = SequentialDrain::new(vec![Box::new(child1), Box::new(split_child)]); + let mut executor = NoopRequestExecutor; + let mut topology = NoopTopologyProvider; + let mut context = PipelineContext::new(&mut executor, Some(&mut topology)); + + assert_eq!( + unwrap_page(drain.next_page(&mut context).await).body(), + b"c1" + ); + assert_eq!( + unwrap_page(drain.next_page(&mut context).await).body(), + b"last-split" + ); + assert_drained(drain.next_page(&mut context).await); + } + + #[tokio::test] + async fn handles_cascading_split() { + let final_leaf = MockLeaf::with_pages(vec![ + Ok(PageResult::Page { + response: response(b"final"), + is_terminal: false, + }), + Ok(PageResult::Drained), + ]); + + let cascading_replacement = MockLeaf::with_pages(vec![Ok(PageResult::SplitRequired { + replacement_nodes: vec![Box::new(final_leaf)], + })]); + + let initial_split = MockLeaf::with_pages(vec![Ok(PageResult::SplitRequired { + replacement_nodes: vec![Box::new(cascading_replacement)], + })]); + + let mut drain = SequentialDrain::new(vec![Box::new(initial_split)]); + let mut executor = NoopRequestExecutor; + let mut topology = NoopTopologyProvider; + let mut context = PipelineContext::new(&mut executor, Some(&mut topology)); + + assert_eq!( + unwrap_page(drain.next_page(&mut context).await).body(), + b"final" + ); + assert_drained(drain.next_page(&mut context).await); + } + + #[tokio::test] + async fn split_retry_limit_prevents_infinite_loop() { + let mut current: Box = + Box::new(MockLeaf::with_pages(vec![Ok(PageResult::Page { + response: response(b"unreachable"), + is_terminal: false, + })])); + + for _ in 0..12 { + current = Box::new(MockLeaf::with_pages(vec![Ok(PageResult::SplitRequired { + replacement_nodes: vec![current], + })])); + } + + let mut drain = SequentialDrain::new(vec![current]); + let mut executor = NoopRequestExecutor; + let mut topology = NoopTopologyProvider; + let mut context = PipelineContext::new(&mut executor, Some(&mut topology)); + + let err = drain.next_page(&mut context).await.unwrap_err(); + assert_eq!( + err.to_string(), + "exceeded maximum split retries (10) in SequentialDrain" + ); + } + + #[tokio::test] + async fn child_drained_immediately_skips_to_next() { + let empty_child = MockLeaf::with_pages(vec![Ok(PageResult::Drained)]); + let real_child = MockLeaf::with_pages(vec![ + Ok(PageResult::Page { + response: response(b"data"), + is_terminal: false, + }), + Ok(PageResult::Drained), + ]); + + let mut drain = SequentialDrain::new(vec![Box::new(empty_child), Box::new(real_child)]); + let mut executor = NoopRequestExecutor; + let mut topology = NoopTopologyProvider; + let mut context = PipelineContext::new(&mut executor, Some(&mut topology)); + + assert_eq!( + unwrap_page(drain.next_page(&mut context).await).body(), + b"data" + ); + assert_drained(drain.next_page(&mut context).await); + } + + #[tokio::test] + async fn split_with_three_way_replacement() { + let r1 = MockLeaf::with_pages(vec![ + Ok(PageResult::Page { + response: response(b"r1"), + is_terminal: false, + }), + Ok(PageResult::Drained), + ]); + let r2 = MockLeaf::with_pages(vec![ + Ok(PageResult::Page { + response: response(b"r2"), + is_terminal: false, + }), + Ok(PageResult::Drained), + ]); + let r3 = MockLeaf::with_pages(vec![ + Ok(PageResult::Page { + response: response(b"r3"), + is_terminal: false, + }), + Ok(PageResult::Drained), + ]); + + let split_child = MockLeaf::with_pages(vec![Ok(PageResult::SplitRequired { + replacement_nodes: vec![Box::new(r1), Box::new(r2), Box::new(r3)], + })]); + + let mut drain = SequentialDrain::new(vec![Box::new(split_child)]); + let mut executor = NoopRequestExecutor; + let mut topology = NoopTopologyProvider; + let mut context = PipelineContext::new(&mut executor, Some(&mut topology)); + + assert_eq!( + unwrap_page(drain.next_page(&mut context).await).body(), + b"r1" + ); + assert_eq!( + unwrap_page(drain.next_page(&mut context).await).body(), + b"r2" + ); + assert_eq!( + unwrap_page(drain.next_page(&mut context).await).body(), + b"r3" + ); + assert_drained(drain.next_page(&mut context).await); + } + + #[tokio::test] + async fn error_after_partial_drain() { + let child1 = MockLeaf::with_pages(vec![ + Ok(PageResult::Page { + response: response(b"ok"), + is_terminal: false, + }), + Ok(PageResult::Drained), + ]); + let child2 = MockLeaf::with_pages(vec![Err(azure_core::Error::with_message( + azure_core::error::ErrorKind::Other, + "boom", + ))]); + + let mut drain = SequentialDrain::new(vec![Box::new(child1), Box::new(child2)]); + let mut executor = NoopRequestExecutor; + let mut topology = NoopTopologyProvider; + let mut context = PipelineContext::new(&mut executor, Some(&mut topology)); + + assert_eq!( + unwrap_page(drain.next_page(&mut context).await).body(), + b"ok" + ); + let err = drain.next_page(&mut context).await.unwrap_err(); + assert_eq!(err.to_string(), "boom"); + } + + #[tokio::test] + async fn multiple_pages_per_child_then_advance() { + let child1 = MockLeaf::with_pages(vec![ + Ok(PageResult::Page { + response: response(b"c1-p1"), + is_terminal: false, + }), + Ok(PageResult::Page { + response: response(b"c1-p2"), + is_terminal: false, + }), + Ok(PageResult::Page { + response: response(b"c1-p3"), + is_terminal: false, + }), + Ok(PageResult::Drained), + ]); + let child2 = MockLeaf::with_pages(vec![ + Ok(PageResult::Page { + response: response(b"c2-p1"), + is_terminal: false, + }), + Ok(PageResult::Drained), + ]); + + let mut drain = SequentialDrain::new(vec![Box::new(child1), Box::new(child2)]); + let mut executor = NoopRequestExecutor; + let mut topology = NoopTopologyProvider; + let mut context = PipelineContext::new(&mut executor, Some(&mut topology)); + + assert_eq!( + unwrap_page(drain.next_page(&mut context).await).body(), + b"c1-p1" + ); + assert_eq!( + unwrap_page(drain.next_page(&mut context).await).body(), + b"c1-p2" + ); + assert_eq!( + unwrap_page(drain.next_page(&mut context).await).body(), + b"c1-p3" + ); + assert_eq!( + unwrap_page(drain.next_page(&mut context).await).body(), + b"c2-p1" + ); + assert_drained(drain.next_page(&mut context).await); + } + + #[tokio::test] + async fn split_produces_page_on_same_call() { + let replacement = MockLeaf::with_pages(vec![ + Ok(PageResult::Page { + response: response(b"immediate"), + is_terminal: false, + }), + Ok(PageResult::Drained), + ]); + + let split_child = MockLeaf::with_pages(vec![Ok(PageResult::SplitRequired { + replacement_nodes: vec![Box::new(replacement)], + })]); + + let mut drain = SequentialDrain::new(vec![Box::new(split_child)]); + let mut executor = NoopRequestExecutor; + let mut topology = NoopTopologyProvider; + let mut context = PipelineContext::new(&mut executor, Some(&mut topology)); + + assert_eq!( + unwrap_page(drain.next_page(&mut context).await).body(), + b"immediate" + ); + assert_drained(drain.next_page(&mut context).await); + } + + #[tokio::test] + async fn terminal_page_pops_child_eagerly() { + // The first child returns one terminal page; the drain must pop it + // immediately so a snapshot taken right after the call already + // points at the next child. + let child1 = MockLeaf::with_pages(vec![Ok(PageResult::Page { + response: response(b"c1-final"), + is_terminal: true, + })]) + .with_feed_range(FeedRange::new( + EffectivePartitionKey::from("00"), + EffectivePartitionKey::from("80"), + )); + let child2 = MockLeaf::with_pages(vec![ + Ok(PageResult::Page { + response: response(b"c2-p1"), + is_terminal: false, + }), + Ok(PageResult::Drained), + ]) + .with_feed_range(FeedRange::new( + EffectivePartitionKey::from("80"), + EffectivePartitionKey::from("FF"), + )); + + let mut drain = SequentialDrain::new(vec![Box::new(child1), Box::new(child2)]); + let mut executor = NoopRequestExecutor; + let mut topology = NoopTopologyProvider; + let mut context = PipelineContext::new(&mut executor, Some(&mut topology)); + + let page = unwrap_page(drain.next_page(&mut context).await); + assert_eq!(page.body(), b"c1-final"); + + // Snapshot must already reference child2 (cursor at "80"), not the + // just-drained child1. + let snapshot = drain.snapshot_state(); + let PipelineNodeState::SequentialDrain { + current_min_epk, .. + } = snapshot + else { + panic!("expected SequentialDrain snapshot, got {snapshot:?}"); + }; + assert_eq!(current_min_epk, "80"); + } + + #[tokio::test] + async fn terminal_page_on_last_child_marks_drain_terminal() { + let only_child = MockLeaf::with_pages(vec![Ok(PageResult::Page { + response: response(b"final"), + is_terminal: true, + })]) + .with_feed_range(FeedRange::new( + EffectivePartitionKey::from("00"), + EffectivePartitionKey::from("FF"), + )); + + let mut drain = SequentialDrain::new(vec![Box::new(only_child)]); + let mut executor = NoopRequestExecutor; + let mut topology = NoopTopologyProvider; + let mut context = PipelineContext::new(&mut executor, Some(&mut topology)); + + match drain.next_page(&mut context).await.unwrap() { + PageResult::Page { + response, + is_terminal, + } => { + assert_eq!(response.body(), b"final"); + assert!(is_terminal, "drain must propagate terminal flag"); + } + other => panic!("expected Page, got {other:?}"), + } + assert!(matches!(drain.snapshot_state(), PipelineNodeState::Drained)); + } +} diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/drained.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/drained.rs new file mode 100644 index 00000000000..ac21f45bbbe --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/drained.rs @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//! A trivial leaf node that immediately reports `Drained`. +//! +//! Used when reconstructing a pipeline from a continuation token whose +//! [`PipelineNodeState::Drained`](super::PipelineNodeState) snapshot indicates +//! the operation already completed. Allows the SDK iterator to behave +//! uniformly without the planner having to special-case the "already done" +//! state. + +use async_trait::async_trait; + +use super::{PageResult, PipelineContext, PipelineNode, PipelineNodeState}; + +pub(crate) struct DrainedLeaf; + +#[async_trait] +impl PipelineNode for DrainedLeaf { + async fn next_page( + &mut self, + _context: &mut PipelineContext<'_>, + ) -> azure_core::Result { + Ok(PageResult::Drained) + } + + #[cfg(test)] + fn into_children(self) -> Vec> { + Vec::new() + } + + fn snapshot_state(&self) -> PipelineNodeState { + PipelineNodeState::Drained + } +} diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/mocks.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/mocks.rs new file mode 100644 index 00000000000..6f7019b173c --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/mocks.rs @@ -0,0 +1,269 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//! Shared test mocks for dataflow pipeline testing. + +use std::{collections::VecDeque, sync::Arc}; + +use azure_core::http::StatusCode; +use futures::future::BoxFuture; + +use super::{ + PageResult, PartitionRoutingRefresh, PipelineContext, PipelineNode, PipelineNodeState, + RequestExecutor, RequestTarget, ResolvedRange, TopologyProvider, +}; +use crate::{ + diagnostics::DiagnosticsContextBuilder, + models::{ + effective_partition_key::EffectivePartitionKey, AccountReference, ActivityId, + CosmosOperation, CosmosResponse, CosmosResponseHeaders, CosmosStatus, DatabaseReference, + FeedRange, PartitionKey, SubStatusCode, + }, + options::DiagnosticsOptions, +}; + +// ── Mock pipeline node ────────────────────────────────────────────────────── + +/// A mock leaf node that returns pre-configured page results. +pub(crate) struct MockLeaf { + pages: VecDeque>, + feed_range: Option, +} + +impl MockLeaf { + /// Creates a mock leaf with a sequence of results to return from `next_page`. + pub fn with_pages(pages: Vec>) -> Self { + Self { + pages: pages.into(), + feed_range: None, + } + } + + /// Sets the feed range reported by [`PipelineNode::feed_range`]. + #[allow(dead_code)] + pub fn with_feed_range(mut self, range: FeedRange) -> Self { + self.feed_range = Some(range); + self + } +} + +#[async_trait::async_trait] +impl PipelineNode for MockLeaf { + async fn next_page( + &mut self, + _context: &mut PipelineContext<'_>, + ) -> azure_core::Result { + self.pages + .pop_front() + .expect("MockLeaf: no more page results") + } + + #[cfg(test)] + fn into_children(self) -> Vec> { + vec![] + } + + fn snapshot_state(&self) -> PipelineNodeState { + PipelineNodeState::Drained + } + + fn feed_range(&self) -> Option<&FeedRange> { + self.feed_range.as_ref() + } +} + +// ── Request executors ─────────────────────────────────────────────────────── + +/// A request executor that should never be called. +pub(crate) struct NoopRequestExecutor; + +impl RequestExecutor for NoopRequestExecutor { + fn execute_request<'a>( + &'a mut self, + _operation: &'a CosmosOperation, + _target: RequestTarget, + _partition_routing_refresh: PartitionRoutingRefresh, + _continuation: Option, + ) -> BoxFuture<'a, azure_core::Result> { + Box::pin(async { + Err(azure_core::Error::with_message( + azure_core::error::ErrorKind::Other, + "noop executor should not be called", + )) + }) + } +} + +/// A mock request executor that records calls and returns pre-configured responses. +pub(crate) struct MockRequestExecutor { + pub responses: VecDeque>, + pub refresh_calls: Vec, + pub continuation_calls: Vec>, +} + +impl MockRequestExecutor { + pub fn new(responses: Vec>) -> Self { + Self { + responses: responses.into(), + refresh_calls: Vec::new(), + continuation_calls: Vec::new(), + } + } +} + +impl RequestExecutor for MockRequestExecutor { + fn execute_request<'a>( + &'a mut self, + _operation: &'a CosmosOperation, + _target: RequestTarget, + partition_routing_refresh: PartitionRoutingRefresh, + continuation: Option, + ) -> BoxFuture<'a, azure_core::Result> { + self.refresh_calls.push(partition_routing_refresh); + self.continuation_calls.push(continuation); + let response = self.responses.pop_front().expect("mock request response"); + Box::pin(async move { response }) + } +} + +// ── Topology providers ───────────────────────────────────────────────────── + +/// A topology provider that should never be called. +pub(crate) struct NoopTopologyProvider; + +impl TopologyProvider for NoopTopologyProvider { + fn resolve_ranges<'a>( + &'a mut self, + _range: &'a FeedRange, + _refresh: PartitionRoutingRefresh, + ) -> BoxFuture<'a, azure_core::Result>> { + Box::pin(async { + Err(azure_core::Error::with_message( + azure_core::error::ErrorKind::Other, + "noop topology provider should not be called", + )) + }) + } +} + +/// A mock topology provider that returns pre-configured resolved ranges. +pub(crate) struct MockTopologyProvider { + results: VecDeque>>, +} + +impl MockTopologyProvider { + pub fn new(results: Vec>>) -> Self { + Self { + results: results.into(), + } + } +} + +impl TopologyProvider for MockTopologyProvider { + fn resolve_ranges<'a>( + &'a mut self, + _range: &'a FeedRange, + _refresh: PartitionRoutingRefresh, + ) -> BoxFuture<'a, azure_core::Result>> { + let result = self + .results + .pop_front() + .expect("MockTopologyProvider: no more results"); + Box::pin(async move { result }) + } +} + +// ── Test helpers ──────────────────────────────────────────────────────────── + +/// Extracts the `CosmosResponse` from a `PageResult::Page`, panicking otherwise. +pub(crate) fn unwrap_page(result: azure_core::Result) -> CosmosResponse { + match result.expect("expected Ok result") { + PageResult::Page { response, .. } => response, + PageResult::Drained => panic!("expected Page, got Drained"), + PageResult::SplitRequired { .. } => panic!("expected Page, got SplitRequired"), + } +} + +/// Asserts that a `PageResult` is `Drained`. +pub(crate) fn assert_drained(result: azure_core::Result) { + match result.expect("expected Ok result") { + PageResult::Drained => {} + PageResult::Page { .. } => panic!("expected Drained, got Page"), + PageResult::SplitRequired { .. } => panic!("expected Drained, got SplitRequired"), + } +} + +/// Creates a test `CosmosOperation`. +pub(crate) fn operation() -> CosmosOperation { + let account = AccountReference::with_master_key( + url::Url::parse("https://test.documents.azure.com:443/").unwrap(), + "dGVzdA==", + ); + let database = DatabaseReference::from_name(account, "db".to_owned()); + CosmosOperation::read_database(database) +} + +/// Creates a `RequestTarget` for a logical partition key. +pub(crate) fn logical_partition_target() -> RequestTarget { + RequestTarget::LogicalPartitionKey(PartitionKey::from("pk")) +} + +/// Creates a `RequestTarget` for an EPK range ("" to "80", partition key range ID "0"). +pub(crate) fn epk_range_target() -> RequestTarget { + RequestTarget::PartitionKeyRange { + range: FeedRange::new( + EffectivePartitionKey::min(), + EffectivePartitionKey::from("80"), + ), + partition_key_range_id: "0".to_string(), + } +} + +/// Creates a test response with the given body. +pub(crate) fn response(body: &[u8]) -> CosmosResponse { + response_with_continuation(body, None) +} + +/// Creates a test response with the given body and optional continuation token. +pub(crate) fn response_with_continuation( + body: &[u8], + continuation: Option<&str>, +) -> CosmosResponse { + let mut diagnostics = DiagnosticsContextBuilder::new( + ActivityId::new_uuid(), + Arc::new(DiagnosticsOptions::default()), + ); + diagnostics.set_operation_status(StatusCode::Ok, None); + let mut headers = CosmosResponseHeaders::new(); + headers.continuation = continuation.map(str::to_owned); + CosmosResponse::new( + body.to_vec(), + headers, + CosmosStatus::new(StatusCode::Ok), + Arc::new(diagnostics.complete()), + ) +} + +/// Creates a 410 Gone error with a partition topology change substatus. +pub(crate) fn gone_error() -> azure_core::Error { + azure_core::Error::new( + azure_core::error::ErrorKind::HttpResponse { + status: StatusCode::Gone, + error_code: Some(SubStatusCode::PARTITION_KEY_RANGE_GONE.value().to_string()), + raw_response: None, + }, + "partition topology changed", + ) +} + +/// Creates a 410 Gone error with a non-topology substatus. +pub(crate) fn non_topology_gone_error() -> azure_core::Error { + azure_core::Error::new( + azure_core::error::ErrorKind::HttpResponse { + status: StatusCode::Gone, + error_code: Some(SubStatusCode::NAME_CACHE_STALE.value().to_string()), + raw_response: None, + }, + "name cache is stale", + ) +} diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/mod.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/mod.rs new file mode 100644 index 00000000000..83e2dda65cd --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/mod.rs @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//! Dataflow pipeline nodes for paged Cosmos DB operations. +//! +//! Everything in this module is driver-internal except [`OperationPlan`], +//! which is the only type re-exported to public APIs. The rest is the +//! machinery `CosmosDriver` uses to plan, execute, and resume paged +//! operations. +//! +//! # Navigation map +//! +//! - Leaf nodes: [`Request`] (executes a single Cosmos DB request and pages +//! through continuation tokens) and [`DrainedLeaf`] (a no-op leaf used when +//! resuming an already-completed plan). +//! - Intermediate nodes: [`SequentialDrain`] iterates EPK-ordered children +//! left-to-right, draining each before advancing. +//! - Planner: [`planner::build_trivial_pipeline`] handles point reads and +//! single-partition operations; [`planner::build_sequential_drain`] handles +//! cross-partition queries by consuming a backend query plan and resolving +//! it against the current topology. +//! - Serializable state: [`PipelineNodeState`] (see [`snapshot`]) is the +//! in-memory shape of a continuation snapshot; the wire-format token lives +//! in [`crate::models::ContinuationToken`]. +//! - Topology adapter: [`CachedTopologyProvider`] backs the +//! [`TopologyProvider`] trait with the driver's +//! [`PartitionKeyRangeCache`](crate::driver::cache::PartitionKeyRangeCache). +//! +//! See `FEED_OPERATIONS_REQS.md` for the design intent behind the dataflow +//! pipeline (paged operations, split recovery, continuation tokens, planned +//! cross-partition strategies). + +mod context; +mod drain; +mod drained; +#[cfg(test)] +pub(crate) mod mocks; +mod node; +mod pipeline; +pub(crate) mod planner; +pub(crate) mod query_plan; +mod request; +mod snapshot; +mod topology; + +pub(crate) use context::{ + PartitionRoutingRefresh, PipelineContext, RequestExecutor, ResolvedRange, TopologyProvider, +}; +pub(crate) use drain::SequentialDrain; +pub(crate) use drained::DrainedLeaf; +pub(crate) use node::{PageResult, PipelineNode}; +pub use pipeline::OperationPlan; +pub(crate) use pipeline::Pipeline; +pub(crate) use request::{Request, RequestTarget}; +pub(crate) use snapshot::PipelineNodeState; +pub(crate) use topology::CachedTopologyProvider; + +#[cfg(test)] +mod tests { + use super::mocks::*; + use super::*; + + #[tokio::test] + async fn pipeline_forwards_pages_from_root() { + let mut pipeline = + Pipeline::new(Box::new(MockLeaf::with_pages(vec![Ok(PageResult::Page { + response: response(b"page"), + is_terminal: false, + })]))); + let mut executor = NoopRequestExecutor; + let mut topology = NoopTopologyProvider; + let mut context = PipelineContext::new(&mut executor, Some(&mut topology)); + + let page = pipeline.next_page(&mut context).await.unwrap().unwrap(); + + assert_eq!(page.body(), b"page"); + } +} diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/node.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/node.rs new file mode 100644 index 00000000000..e2203f7d2ad --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/node.rs @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//! [`PipelineNode`] trait and [`PageResult`] returned from each pull. + +use async_trait::async_trait; + +use crate::models::{CosmosResponse, FeedRange}; + +use super::{context::PipelineContext, snapshot::PipelineNodeState}; + +/// Result of a single `next_page` call on a pipeline node. +/// +/// The `Page` variant contains a large `CosmosResponse` inline, but boxing it +/// would add a heap allocation on every page fetch — the hot path. The `SplitRequired` +/// variant is rare (only on partition splits), so the size difference is acceptable. +#[must_use = "a PageResult carries the next page, drain signal, or a split request that the caller must act on"] +#[allow(clippy::large_enum_variant)] +pub(crate) enum PageResult { + /// A page of results was produced. + /// + /// `is_terminal` is `true` when this node has no more pages to emit + /// after this one — set by leaf nodes when the server returned no + /// continuation token, and propagated by intermediate nodes when their + /// last child has emitted its terminal page. Parents use this to evict + /// drained children eagerly so that snapshots of the pipeline do not + /// include children that are already done. + Page { + response: CosmosResponse, + is_terminal: bool, + }, + /// This node has no more pages to emit. + Drained, + /// This node's EPK range has split and needs to be replaced by new child nodes. + /// + /// It is the parent intermediate node's responsibility to splice + /// `replacement_nodes` into its children list (in place of the child that + /// emitted this result) and re-attempt draining from the first replacement. + /// If a node returns `SplitRequired` to a parent that does not handle + /// splits (e.g. the pipeline root), the operation fails. + SplitRequired { + /// New child nodes covering the sub-ranges of the split partition. + replacement_nodes: Vec>, + }, +} + +impl std::fmt::Debug for PageResult { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PageResult::Page { is_terminal, .. } => { + write!(f, "Page(terminal={is_terminal})") + } + PageResult::Drained => f.write_str("Drained"), + PageResult::SplitRequired { + replacement_nodes, .. + } => write!(f, "SplitRequired({} nodes)", replacement_nodes.len()), + } + } +} + +/// A dataflow node that emits pages and may own child nodes. +/// +/// Each `next_page` call boxes a future via `async_trait`; the per-page +/// allocation is negligible compared to the multi-millisecond network I/O +/// of a Cosmos DB request. +#[async_trait] +pub(crate) trait PipelineNode: Send + std::any::Any { + /// Emits the next page of results, signals drain completion, or requests a split. + async fn next_page( + &mut self, + context: &mut PipelineContext<'_>, + ) -> azure_core::Result; + + /// Consumes this node and returns its children as a `Vec`. + /// + /// Used by tests to inspect the dataflow tree's shape after planning. + #[cfg(test)] + fn into_children(self) -> Vec>; + + /// Snapshots this node's state for continuation-token serialization. + fn snapshot_state(&self) -> PipelineNodeState; + + /// Returns the EPK range this node currently targets, if known. + /// + /// Used by intermediate nodes (e.g. [`super::SequentialDrain`]) to record + /// the current cursor position when snapshotting, without needing to know + /// the concrete type of their children. Defaults to `None`. + /// + /// # Invariant + /// + /// Every node in the dataflow tree is responsible for some contiguous EPK + /// sub-range of the container key space. Intermediate nodes that drain + /// children in EPK order (such as [`super::SequentialDrain`]) may use the + /// front child's `feed_range()` as their own cursor; intermediates that + /// combine results across ranges (e.g. a future k-way merge for streaming + /// `ORDER BY`) are responsible for snapshotting whatever cursor + /// representation makes sense for their ordering semantics. + fn feed_range(&self) -> Option<&FeedRange> { + None + } +} + +#[cfg(test)] +impl dyn PipelineNode { + /// Downcasts this node to a concrete type. + pub(crate) fn downcast_ref(&self) -> Option<&T> { + (self as &dyn std::any::Any).downcast_ref::() + } + + /// Downcasts this node to a concrete type. + pub(crate) fn downcast(self: Box) -> Option> { + (self as Box).downcast::().ok() + } +} diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/pipeline.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/pipeline.rs new file mode 100644 index 00000000000..50f1a75be84 --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/pipeline.rs @@ -0,0 +1,91 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//! [`Pipeline`] (driver-internal) and [`OperationPlan`] (driver-public). + +use crate::models::{ContinuationToken, CosmosResponse}; + +use super::context::PipelineContext; +use super::node::{PageResult, PipelineNode}; +use super::snapshot::PipelineNodeState; + +/// A pipeline root that owns the node tree. +pub(crate) struct Pipeline { + root: Box, +} + +impl std::fmt::Debug for Pipeline { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Pipeline").finish_non_exhaustive() + } +} + +impl Pipeline { + /// Creates a pipeline from an owned root node. + pub(crate) fn new(root: Box) -> Self { + Self { root } + } + + /// Returns a reference to the root node. + #[cfg(test)] + pub(crate) fn root(&self) -> &dyn PipelineNode { + &*self.root + } + + /// Consumes the pipeline and returns the root node. + #[cfg(test)] + pub(crate) fn into_root(self) -> Box { + self.root + } + + /// Emits the next page from the root node. + /// + /// Returns `Ok(Some(response))` for a page, `Ok(None)` when drained. + pub(crate) async fn next_page( + &mut self, + context: &mut PipelineContext<'_>, + ) -> azure_core::Result> { + match self.root.next_page(context).await? { + PageResult::Page { response, .. } => Ok(Some(response)), + PageResult::Drained => Ok(None), + // Defensive: today the root is always a `Request`, `SequentialDrain`, + // or `DrainedLeaf`, none of which can bubble `SplitRequired` up past + // their parent. If a future node type ever does, surfacing it as an + // explicit error is preferable to silently dropping the page. + PageResult::SplitRequired { .. } => Err(azure_core::Error::with_message( + azure_core::error::ErrorKind::Other, + "root node cannot request a split; splits must be handled by a parent node", + )), + } + } + + /// Snapshots the pipeline's current state for continuation-token serialization. + pub(crate) fn snapshot_state(&self) -> PipelineNodeState { + self.root.snapshot_state() + } +} + +/// A plan for executing a Cosmos DB operation. +/// +/// Produced by [`CosmosDriver::plan_operation`](crate::driver::CosmosDriver::plan_operation). +pub struct OperationPlan { + pub(crate) pipeline: Pipeline, +} + +impl OperationPlan { + /// Creates an operation plan wrapping the given pipeline. + pub(crate) fn new(pipeline: Pipeline) -> Self { + Self { pipeline } + } + + /// Snapshots this plan into a [`ContinuationToken`] suitable for cross-process + /// resumption. + /// + /// Snapshotting walks the pipeline tree and serializes a minimal record of + /// each node's progress. The result can be passed back to + /// [`CosmosDriver::plan_operation`](crate::driver::CosmosDriver::plan_operation) + /// (with the same operation) to resume where this plan left off. + pub fn to_continuation_token(&self) -> azure_core::Result { + ContinuationToken::encode_v1(&self.pipeline.snapshot_state()) + } +} diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/planner.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/planner.rs new file mode 100644 index 00000000000..41cae9b97a5 --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/planner.rs @@ -0,0 +1,919 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//! Pipeline planner for Cosmos DB operations. +//! +//! The planner validates an operation's target against its resource type and +//! constructs the appropriate dataflow [`Pipeline`]. +//! +//! For cross-partition queries, [`build_sequential_drain`] consumes a backend +//! [`QueryPlan`](super::query_plan::QueryPlan) and resolves the query's EPK +//! ranges against the current topology to produce a fan-out pipeline. + +use std::sync::Arc; + +use crate::models::{ + effective_partition_key::EffectivePartitionKey, CosmosOperation, FeedRange, OperationTarget, +}; + +use super::{ + query_plan::{QueryInfo, QueryPlan}, + DrainedLeaf, PartitionRoutingRefresh, Pipeline, PipelineNode, PipelineNodeState, Request, + RequestTarget, SequentialDrain, TopologyProvider, +}; + +/// Builds a single-node [`Pipeline`] for a trivial operation. +/// +/// Trivial operations are those that can be satisfied by a single request to +/// one partition (point reads, single-partition queries, metadata operations). +/// Use [`CosmosOperation::is_trivial`] to check eligibility before calling. +/// +/// `operation` is shared with the resulting [`Request`] node via `Arc`; the +/// caller passes ownership in (cheap because the underlying allocation is +/// shared with any other nodes that need the same operation). +/// +/// `resume` is an optional [`PipelineNodeState`] from a continuation token +/// that augments planning. Only `Request` and `Drained` shapes are accepted +/// for trivial operations; any other shape returns a `DataConversion` error. +/// +/// # Panics (debug builds) +/// +/// Debug-asserts that the operation is indeed trivial. In release builds, +/// returns an error if a non-trivial operation (e.g. a cross-partition query) +/// is passed. +pub(crate) fn build_trivial_pipeline( + operation: Arc, + resume: Option, +) -> azure_core::Result { + debug_assert!( + operation.is_trivial(), + "build_trivial_pipeline called with non-trivial operation: {:?} targeting {:?}", + operation.operation_type(), + operation.target(), + ); + + let resource_type = operation.resource_type(); + let target = operation.target(); + + if !resource_type.is_valid_target(target) { + return Err(azure_core::Error::with_message( + azure_core::error::ErrorKind::Other, + format!( + "operation target {target_desc} is not valid for resource type {resource_type}", + target_desc = target_description(target), + ), + )); + } + + let initial_continuation = match resume { + None => None, + Some(PipelineNodeState::Request { + server_continuation, + }) => server_continuation, + Some(PipelineNodeState::Drained) => { + return Ok(Pipeline::new(Box::new(DrainedLeaf))); + } + Some(other) => { + return Err(azure_core::Error::with_message( + azure_core::error::ErrorKind::DataConversion, + format!( + "continuation token shape {} does not match a trivial operation", + snapshot_kind(&other) + ), + )); + } + }; + + let request_target = match target { + OperationTarget::None => RequestTarget::NonPartitioned, + OperationTarget::PartitionKey(pk) => RequestTarget::LogicalPartitionKey(pk.clone()), + OperationTarget::FeedRange(_) => { + return Err(azure_core::Error::with_message( + azure_core::error::ErrorKind::Other, + "FeedRange targeting requires a fan-out pipeline; \ + use plan_operation for cross-partition queries", + )); + } + }; + + let root = Request::new(operation, request_target, initial_continuation); + Ok(Pipeline::new(Box::new(root))) +} + +/// Builds a fan-out [`Pipeline`] from a backend query plan as a sequential drain. +/// +/// Produces either a single [`Request`] leaf (when planning resolves to a +/// single physical partition) or a [`SequentialDrain`] over one `Request` per +/// resolved range. Other cross-partition strategies (streaming `ORDER BY`, +/// hybrid search, read-many, etc.) will live as sibling functions. +/// +/// `operation` is the underlying logical operation shared across every +/// resulting [`Request`] node via `Arc::clone`; per-partition differences +/// (e.g. partition-key-range targeting) are layered on at execution time via +/// [`OperationOverrides`](crate::pipeline::OperationOverrides) and the +/// per-node [`RequestTarget`], not by cloning the operation itself. +/// +/// This function: +/// 1. Validates that the query plan contains no unsupported features (no +/// top/limit, no ordering, no hybrid search, no aggregates). +/// 2. Converts the plan's `queryRanges` to [`FeedRange`]s and resolves them +/// against the current partition topology. +/// 3. Creates a [`Request`] node for each resolved range and bundles them in a +/// [`SequentialDrain`]. +/// +/// `resume` is an optional [`PipelineNodeState`] from a continuation token. +/// When present, ranges whose `max_exclusive <= current_min_epk` are skipped +/// and the server continuation from `left_most` is propagated to the front +/// (resumed) leaf only. +pub(crate) async fn build_sequential_drain( + query_plan: &QueryPlan, + topology_provider: &mut dyn TopologyProvider, + operation: &Arc, + resume: Option, +) -> azure_core::Result { + validate_query_plan(query_plan)?; + + let resume = match resume { + None => None, + Some(PipelineNodeState::Drained) => { + return Ok(Pipeline::new(Box::new(DrainedLeaf))); + } + Some(PipelineNodeState::SequentialDrain { + current_min_epk, + left_most, + }) => { + let server_continuation = match *left_most { + PipelineNodeState::Request { + server_continuation, + } => server_continuation, + PipelineNodeState::Drained => None, + other => { + return Err(azure_core::Error::with_message( + azure_core::error::ErrorKind::DataConversion, + format!( + "continuation token has unsupported nested shape inside SequentialDrain: {}", + snapshot_kind(&other) + ), + )); + } + }; + Some(ResumeCursor { + current_min_epk: EffectivePartitionKey::from(current_min_epk), + server_continuation, + }) + } + Some(PipelineNodeState::Request { + server_continuation, + }) => { + // A bare Request snapshot means the cross-partition query had only + // a single child — apply it as a cursor at the minimum EPK. + Some(ResumeCursor { + current_min_epk: EffectivePartitionKey::min(), + server_continuation, + }) + } + }; + + // Convert query ranges to FeedRanges and resolve against topology. + let mut request_nodes: Vec> = Vec::new(); + let mut resume = resume; + for query_range in &query_plan.query_ranges { + let min = EffectivePartitionKey::from(query_range.min.as_str()); + let max = EffectivePartitionKey::from(query_range.max.as_str()); + let feed_range = FeedRange::new(min, max); + let resolved = topology_provider + .resolve_ranges(&feed_range, PartitionRoutingRefresh::UseCached) + .await?; + + for resolved_range in resolved { + // Skip ranges that are entirely below the resume cursor. + if let Some(cursor) = resume.as_ref() { + if resolved_range.range.max_exclusive() <= &cursor.current_min_epk { + continue; + } + } + + // Carry the server continuation onto the first surviving leaf, + // then clear it so subsequent leaves start fresh. + let initial_continuation = resume.as_mut().and_then(|c| c.server_continuation.take()); + let target = RequestTarget::EffectivePartitionKeyRange { + range: resolved_range.range, + partition_key_range_id: resolved_range.partition_key_range_id, + }; + request_nodes.push(Box::new(Request::new( + Arc::clone(operation), + target, + initial_continuation, + ))); + } + } + + // TODO: enforce max fan-out (default 100, configurable). See FEED_OPERATIONS_REQS.md §3. + + if request_nodes.is_empty() { + // Either the plan had no ranges or everything was below the cursor. + // The latter is a normal "fully drained" outcome — emit a drained leaf. + if resume.is_some() { + return Ok(Pipeline::new(Box::new(DrainedLeaf))); + } + return Err(azure_core::Error::with_message( + azure_core::error::ErrorKind::Other, + "query plan produced no partition ranges to query", + )); + } + + let root: Box = if request_nodes.len() == 1 { + request_nodes.into_iter().next().unwrap() + } else { + Box::new(SequentialDrain::new(request_nodes)) + }; + + Ok(Pipeline::new(root)) +} + +/// Resume cursor extracted from a `SequentialDrain` continuation snapshot. +struct ResumeCursor { + current_min_epk: EffectivePartitionKey, + server_continuation: Option, +} + +fn snapshot_kind(state: &PipelineNodeState) -> &'static str { + match state { + PipelineNodeState::Drained => "Drained", + PipelineNodeState::Request { .. } => "Request", + PipelineNodeState::SequentialDrain { .. } => "SequentialDrain", + } +} + +/// Validates that the query plan does not require features we don't yet support. +fn validate_query_plan(plan: &QueryPlan) -> azure_core::Result<()> { + if plan.hybrid_search_query_info.is_some() { + return Err(unsupported_feature("hybrid search queries")); + } + + if let Some(info) = &plan.query_info { + validate_query_info(info)?; + } + + Ok(()) +} + +fn validate_query_info(info: &QueryInfo) -> azure_core::Result<()> { + if info.top.is_some() { + return Err(unsupported_feature("TOP clause in cross-partition queries")); + } + if info.limit.is_some() { + return Err(unsupported_feature( + "LIMIT clause in cross-partition queries", + )); + } + if !info.order_by.is_empty() { + return Err(unsupported_feature("ORDER BY in cross-partition queries")); + } + if !info.aggregates.is_empty() { + return Err(unsupported_feature("aggregates in cross-partition queries")); + } + if !info.group_by_expressions.is_empty() { + return Err(unsupported_feature("GROUP BY in cross-partition queries")); + } + Ok(()) +} + +fn unsupported_feature(feature: &str) -> azure_core::Error { + azure_core::Error::with_message( + azure_core::error::ErrorKind::Other, + format!("unsupported query feature: {feature}"), + ) +} + +fn target_description(target: &OperationTarget) -> &'static str { + match target { + OperationTarget::None => "None", + OperationTarget::PartitionKey(_) => "PartitionKey", + OperationTarget::FeedRange(_) => "FeedRange", + } +} + +#[cfg(test)] +mod tests { + use std::borrow::Cow; + + use super::*; + use crate::{ + driver::dataflow::{mocks::*, query_plan::QueryRange, ResolvedRange}, + models::{ + effective_partition_key::EffectivePartitionKey, AccountReference, ContainerProperties, + ContainerReference, DatabaseReference, ItemReference, OperationType, PartitionKey, + PartitionKeyDefinition, ResourceType, SystemProperties, + }, + }; + + fn test_account() -> AccountReference { + AccountReference::with_master_key( + url::Url::parse("https://test.documents.azure.com:443/").unwrap(), + "dGVzdA==", + ) + } + + fn test_database() -> DatabaseReference { + DatabaseReference::from_name(test_account(), "db".to_owned()) + } + + fn test_partition_key_definition() -> PartitionKeyDefinition { + serde_json::from_str(r#"{"paths":["/pk"]}"#).unwrap() + } + + fn test_container_props() -> ContainerProperties { + ContainerProperties { + id: Cow::Owned("coll".into()), + partition_key: test_partition_key_definition(), + system_properties: SystemProperties::default(), + } + } + + fn test_container() -> ContainerReference { + ContainerReference::new( + test_account(), + "db", + "db_rid", + "coll", + "coll_rid", + &test_container_props(), + ) + } + + fn cross_partition_query_operation() -> CosmosOperation { + CosmosOperation::query_items( + test_container(), + OperationTarget::FeedRange(FeedRange::full()), + ) + .with_body(br#"{"query":"SELECT * FROM c"}"#.to_vec()) + } + + // --- build_trivial_pipeline tests --- + + #[test] + fn plans_non_partitioned_pipeline_for_database_read() { + let op = CosmosOperation::read_database(test_database()); + let pipeline = build_trivial_pipeline(Arc::new(op), None).unwrap(); + + let request = pipeline.root().downcast_ref::().unwrap(); + assert_eq!(*request.target(), RequestTarget::NonPartitioned); + assert_eq!(request.operation().operation_type(), OperationType::Read); + assert_eq!(request.operation().resource_type(), ResourceType::Database); + } + + #[test] + fn plans_logical_partition_pipeline_for_item_read() { + let pk = PartitionKey::from("pk-value"); + let item = ItemReference::from_name(&test_container(), pk.clone(), "doc1"); + let op = CosmosOperation::read_item(item); + let pipeline = build_trivial_pipeline(Arc::new(op), None).unwrap(); + + let request = pipeline.root().downcast_ref::().unwrap(); + assert_eq!( + *request.target(), + RequestTarget::LogicalPartitionKey(pk.clone()) + ); + assert_eq!(request.operation().operation_type(), OperationType::Read); + assert_eq!(request.operation().resource_type(), ResourceType::Document); + } + + #[test] + fn rejects_feed_range_target() { + let op = CosmosOperation::read_all_items_cross_partition(test_container()); + + // In debug builds, this panics via debug_assert; in release builds it returns Err. + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + build_trivial_pipeline(Arc::new(op), None) + })); + + match result { + // Panicked in debug mode (expected) + Err(_) if cfg!(debug_assertions) => {} + // Panicked in release mode (bad) + Err(_) => panic!("did not expect panic for FeedRange target"), + // Returned Err in release mode (also acceptable) + Ok(Err(err)) => { + assert_eq!( + err.to_string(), + "FeedRange targeting requires a fan-out pipeline; \ + use plan_operation for cross-partition queries" + ); + } + _ => panic!("expected error or panic for FeedRange target"), + } + } + + // --- build_sequential_drain tests --- + + /// Shorthand to build a `QueryRange` from hex-prefix EPK strings. + fn qr(min: &str, max: &str) -> QueryRange { + QueryRange { + min: min.to_string(), + max: max.to_string(), + is_min_inclusive: true, + is_max_inclusive: false, + } + } + + /// Shorthand to build a `ResolvedRange` from (min, max, pk_range_id). + fn rr(min: &str, max: &str, pk_range_id: &str) -> ResolvedRange { + ResolvedRange { + partition_key_range_id: pk_range_id.to_string(), + range: FeedRange::new( + EffectivePartitionKey::from(min), + EffectivePartitionKey::from(max), + ), + } + } + + /// Builds a query plan with the given query ranges (and no query info). + fn plan_with_ranges(ranges: Vec) -> QueryPlan { + QueryPlan { + partitioned_query_execution_info_version: 1, + query_info: None, + query_ranges: ranges, + hybrid_search_query_info: None, + } + } + + /// Asserts that the pipeline is a single `Request` targeting the expected EPK range. + fn assert_single_request( + pipeline: &Pipeline, + expected_min: &str, + expected_max: &str, + expected_pk_range_id: &str, + ) { + let request = pipeline + .root() + .downcast_ref::() + .expect("expected single Request root"); + assert_eq!( + *request.target(), + RequestTarget::EffectivePartitionKeyRange { + range: FeedRange::new( + EffectivePartitionKey::from(expected_min), + EffectivePartitionKey::from(expected_max), + ), + partition_key_range_id: expected_pk_range_id.to_string(), + } + ); + } + + /// Asserts that the pipeline is a `SequentialDrain` containing `Request` nodes + /// targeting the given EPK ranges (in order). + fn assert_drain_requests(pipeline: Pipeline, expected: &[(&str, &str, &str)]) { + let drain = pipeline + .into_root() + .downcast::() + .expect("expected SequentialDrain root"); + let children = drain.into_children(); + assert_eq!( + children.len(), + expected.len(), + "expected {} request nodes, got {}", + expected.len(), + children.len(), + ); + for (child, &(min, max, pk_range_id)) in children.into_iter().zip(expected) { + let request = child + .downcast::() + .expect("expected Request child node"); + assert_eq!( + *request.target(), + RequestTarget::EffectivePartitionKeyRange { + range: FeedRange::new( + EffectivePartitionKey::from(min), + EffectivePartitionKey::from(max), + ), + partition_key_range_id: pk_range_id.to_string(), + }, + "mismatch for pk range {pk_range_id}" + ); + } + } + + #[tokio::test] + async fn builds_single_node_pipeline_for_one_partition() { + let plan = plan_with_ranges(vec![qr("", "FF")]); + let op = cross_partition_query_operation(); + let mut topology = MockTopologyProvider::new(vec![Ok(vec![rr("", "FF", "pkrange-0")])]); + + let pipeline = build_sequential_drain(&plan, &mut topology, &Arc::new(op), None) + .await + .unwrap(); + assert_single_request(&pipeline, "", "FF", "pkrange-0"); + } + + #[tokio::test] + async fn builds_sequential_drain_for_multiple_partitions() { + // Query targets full range, topology has two partitions split at "80". + let plan = plan_with_ranges(vec![qr("", "FF")]); + let op = cross_partition_query_operation(); + let mut topology = MockTopologyProvider::new(vec![Ok(vec![ + rr("", "80", "pkrange-left"), + rr("80", "FF", "pkrange-right"), + ])]); + + let pipeline = build_sequential_drain(&plan, &mut topology, &Arc::new(op), None) + .await + .unwrap(); + assert_drain_requests( + pipeline, + &[("", "80", "pkrange-left"), ("80", "FF", "pkrange-right")], + ); + } + + #[tokio::test] + async fn builds_pipeline_for_multiple_query_ranges() { + // Query plan specifies two disjoint query ranges; each resolves to one partition. + let plan = plan_with_ranges(vec![qr("", "40"), qr("80", "FF")]); + let op = cross_partition_query_operation(); + let mut topology = MockTopologyProvider::new(vec![ + Ok(vec![rr("", "40", "pkrange-A")]), + Ok(vec![rr("80", "FF", "pkrange-C")]), + ]); + + let pipeline = build_sequential_drain(&plan, &mut topology, &Arc::new(op), None) + .await + .unwrap(); + assert_drain_requests( + pipeline, + &[("", "40", "pkrange-A"), ("80", "FF", "pkrange-C")], + ); + } + + #[tokio::test] + async fn query_range_spans_multiple_topology_partitions() { + // A single query range [00, C0) spans three topology partitions. + let plan = plan_with_ranges(vec![qr("00", "C0")]); + let op = cross_partition_query_operation(); + let mut topology = MockTopologyProvider::new(vec![Ok(vec![ + rr("00", "40", "pkrange-1"), + rr("40", "80", "pkrange-2"), + rr("80", "C0", "pkrange-3"), + ])]); + + let pipeline = build_sequential_drain(&plan, &mut topology, &Arc::new(op), None) + .await + .unwrap(); + assert_drain_requests( + pipeline, + &[ + ("00", "40", "pkrange-1"), + ("40", "80", "pkrange-2"), + ("80", "C0", "pkrange-3"), + ], + ); + } + + #[tokio::test] + async fn multiple_query_ranges_each_spanning_multiple_partitions() { + // Two query ranges, each resolving to multiple partitions. The resulting + // pipeline should have all resolved ranges in order. + let plan = plan_with_ranges(vec![qr("", "60"), qr("A0", "FF")]); + let op = cross_partition_query_operation(); + let mut topology = MockTopologyProvider::new(vec![ + // First query range [, 60) spans two partitions. + Ok(vec![ + rr("", "30", "pkrange-alpha"), + rr("30", "60", "pkrange-beta"), + ]), + // Second query range [A0, FF) spans two partitions. + Ok(vec![ + rr("A0", "D0", "pkrange-gamma"), + rr("D0", "FF", "pkrange-delta"), + ]), + ]); + + let pipeline = build_sequential_drain(&plan, &mut topology, &Arc::new(op), None) + .await + .unwrap(); + assert_drain_requests( + pipeline, + &[ + ("", "30", "pkrange-alpha"), + ("30", "60", "pkrange-beta"), + ("A0", "D0", "pkrange-gamma"), + ("D0", "FF", "pkrange-delta"), + ], + ); + } + + #[tokio::test] + async fn topology_partition_wider_than_query_range() { + // The topology partition [, FF) is wider than query range [20, 80). + // The resolved range matches the topology, not the query range. + let plan = plan_with_ranges(vec![qr("20", "80")]); + let op = cross_partition_query_operation(); + let mut topology = MockTopologyProvider::new(vec![Ok(vec![rr("", "FF", "pkrange-wide")])]); + + let pipeline = build_sequential_drain(&plan, &mut topology, &Arc::new(op), None) + .await + .unwrap(); + assert_single_request(&pipeline, "", "FF", "pkrange-wide"); + } + + #[tokio::test] + async fn rejects_query_plan_with_top() { + let plan = QueryPlan { + query_info: Some(QueryInfo { + top: Some(10), + ..Default::default() + }), + ..plan_with_ranges(vec![qr("", "FF")]) + }; + let op = cross_partition_query_operation(); + let mut topology = NoopTopologyProvider; + + let err = build_sequential_drain(&plan, &mut topology, &Arc::new(op), None) + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "unsupported query feature: TOP clause in cross-partition queries" + ); + } + + #[tokio::test] + async fn rejects_query_plan_with_limit() { + let plan = QueryPlan { + query_info: Some(QueryInfo { + limit: Some(20), + ..Default::default() + }), + ..plan_with_ranges(vec![qr("", "FF")]) + }; + let op = cross_partition_query_operation(); + let mut topology = NoopTopologyProvider; + + let err = build_sequential_drain(&plan, &mut topology, &Arc::new(op), None) + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "unsupported query feature: LIMIT clause in cross-partition queries" + ); + } + + #[tokio::test] + async fn rejects_query_plan_with_order_by() { + use super::super::query_plan::SortOrder; + let plan = QueryPlan { + query_info: Some(QueryInfo { + order_by: vec![SortOrder::Ascending], + ..Default::default() + }), + ..plan_with_ranges(vec![qr("", "FF")]) + }; + let op = cross_partition_query_operation(); + let mut topology = NoopTopologyProvider; + + let err = build_sequential_drain(&plan, &mut topology, &Arc::new(op), None) + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "unsupported query feature: ORDER BY in cross-partition queries" + ); + } + + #[tokio::test] + async fn rejects_query_plan_with_aggregates() { + let plan = QueryPlan { + query_info: Some(QueryInfo { + aggregates: vec!["Count".to_string()], + ..Default::default() + }), + ..plan_with_ranges(vec![qr("", "FF")]) + }; + let op = cross_partition_query_operation(); + let mut topology = NoopTopologyProvider; + + let err = build_sequential_drain(&plan, &mut topology, &Arc::new(op), None) + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "unsupported query feature: aggregates in cross-partition queries" + ); + } + + #[tokio::test] + async fn rejects_query_plan_with_group_by() { + let plan = QueryPlan { + query_info: Some(QueryInfo { + group_by_expressions: vec!["c.category".to_string()], + ..Default::default() + }), + ..plan_with_ranges(vec![qr("", "FF")]) + }; + let op = cross_partition_query_operation(); + let mut topology = NoopTopologyProvider; + + let err = build_sequential_drain(&plan, &mut topology, &Arc::new(op), None) + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "unsupported query feature: GROUP BY in cross-partition queries" + ); + } + + #[tokio::test] + async fn rejects_query_plan_with_hybrid_search() { + let plan = QueryPlan { + hybrid_search_query_info: Some(super::super::query_plan::HybridSearchQueryInfo { + global_statistics_query: "SELECT COUNT(1) FROM c".to_string(), + component_query_infos: vec![], + component_weights: vec![], + skip: None, + take: Some(10), + requires_global_statistics: true, + }), + ..plan_with_ranges(vec![qr("", "FF")]) + }; + let op = cross_partition_query_operation(); + let mut topology = NoopTopologyProvider; + + let err = build_sequential_drain(&plan, &mut topology, &Arc::new(op), None) + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "unsupported query feature: hybrid search queries" + ); + } + + #[tokio::test] + async fn accepts_query_plan_with_no_query_info() { + let plan = plan_with_ranges(vec![qr("", "FF")]); + let op = cross_partition_query_operation(); + let mut topology = MockTopologyProvider::new(vec![Ok(vec![rr("", "FF", "pkrange-0")])]); + + let pipeline = build_sequential_drain(&plan, &mut topology, &Arc::new(op), None) + .await + .unwrap(); + assert_single_request(&pipeline, "", "FF", "pkrange-0"); + } + + #[tokio::test] + async fn rejects_empty_query_ranges() { + let plan = plan_with_ranges(vec![]); + let op = cross_partition_query_operation(); + let mut topology = NoopTopologyProvider; + + let err = build_sequential_drain(&plan, &mut topology, &Arc::new(op), None) + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "query plan produced no partition ranges to query" + ); + } + + #[tokio::test] + async fn propagates_topology_resolution_error() { + let plan = plan_with_ranges(vec![qr("", "FF")]); + let op = cross_partition_query_operation(); + let mut topology = MockTopologyProvider::new(vec![Err(azure_core::Error::with_message( + azure_core::error::ErrorKind::Other, + "topology resolution failed", + ))]); + + let err = build_sequential_drain(&plan, &mut topology, &Arc::new(op), None) + .await + .unwrap_err(); + assert_eq!(err.to_string(), "topology resolution failed"); + } + + // ----------------------------------------------------------------- + // Resume tests + // ----------------------------------------------------------------- + + #[tokio::test] + async fn resume_drained_state_yields_drained_pipeline() { + let plan = plan_with_ranges(vec![qr("", "FF")]); + let op = cross_partition_query_operation(); + let mut topology = MockTopologyProvider::new(vec![Ok(vec![rr("", "FF", "pkrange-0")])]); + + let pipeline = build_sequential_drain( + &plan, + &mut topology, + &Arc::new(op), + Some(PipelineNodeState::Drained), + ) + .await + .unwrap(); + + // The drained pipeline immediately yields no pages. + assert!(matches!( + pipeline.snapshot_state(), + PipelineNodeState::Drained + )); + } + + #[tokio::test] + async fn resume_skips_ranges_below_cursor() { + let plan = plan_with_ranges(vec![qr("", "FF")]); + let op = cross_partition_query_operation(); + let mut topology = MockTopologyProvider::new(vec![Ok(vec![ + rr("", "55", "pk-a"), + rr("55", "AA", "pk-b"), + rr("AA", "FF", "pk-c"), + ])]); + + // Cursor sitting at the first byte of the second range — the first + // range (max_exclusive == "55") must be skipped, the others kept. + let resume = PipelineNodeState::SequentialDrain { + current_min_epk: "55".to_owned(), + left_most: Box::new(PipelineNodeState::Request { + server_continuation: None, + }), + }; + + let pipeline = build_sequential_drain(&plan, &mut topology, &Arc::new(op), Some(resume)) + .await + .unwrap(); + assert_drain_requests(pipeline, &[("55", "AA", "pk-b"), ("AA", "FF", "pk-c")]); + } + + #[tokio::test] + async fn resume_propagates_server_continuation_to_first_surviving_leaf_only() { + let plan = plan_with_ranges(vec![qr("", "FF")]); + let op = cross_partition_query_operation(); + let mut topology = MockTopologyProvider::new(vec![Ok(vec![ + rr("", "55", "pk-a"), + rr("55", "AA", "pk-b"), + rr("AA", "FF", "pk-c"), + ])]); + + let resume = PipelineNodeState::SequentialDrain { + current_min_epk: "55".to_owned(), + left_most: Box::new(PipelineNodeState::Request { + server_continuation: Some("server-token-xyz".to_owned()), + }), + }; + + let pipeline = build_sequential_drain(&plan, &mut topology, &Arc::new(op), Some(resume)) + .await + .unwrap(); + let snapshot = pipeline.snapshot_state(); + let PipelineNodeState::SequentialDrain { left_most, .. } = snapshot else { + panic!("expected SequentialDrain snapshot, got {snapshot:?}"); + }; + assert_eq!( + *left_most, + PipelineNodeState::Request { + server_continuation: Some("server-token-xyz".to_owned()), + }, + "front leaf must carry the resumed server continuation", + ); + } + + #[tokio::test] + async fn resume_with_cursor_past_all_ranges_yields_drained_pipeline() { + let plan = plan_with_ranges(vec![qr("", "FF")]); + let op = cross_partition_query_operation(); + let mut topology = MockTopologyProvider::new(vec![Ok(vec![rr("", "55", "pk-a")])]); + + let resume = PipelineNodeState::SequentialDrain { + current_min_epk: "FF".to_owned(), + left_most: Box::new(PipelineNodeState::Drained), + }; + + let pipeline = build_sequential_drain(&plan, &mut topology, &Arc::new(op), Some(resume)) + .await + .unwrap(); + assert!(matches!( + pipeline.snapshot_state(), + PipelineNodeState::Drained + )); + } + + #[tokio::test] + async fn resume_rejects_nested_sequential_drain_inside_left_most() { + let plan = plan_with_ranges(vec![qr("", "FF")]); + let op = cross_partition_query_operation(); + let mut topology = MockTopologyProvider::new(vec![Ok(vec![rr("", "FF", "pk-a")])]); + + let resume = PipelineNodeState::SequentialDrain { + current_min_epk: "00".to_owned(), + left_most: Box::new(PipelineNodeState::SequentialDrain { + current_min_epk: "00".to_owned(), + left_most: Box::new(PipelineNodeState::Request { + server_continuation: None, + }), + }), + }; + + let err = build_sequential_drain(&plan, &mut topology, &Arc::new(op), Some(resume)) + .await + .unwrap_err(); + assert!( + err.to_string().contains("unsupported nested shape"), + "unexpected error message: {err}", + ); + } +} diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/query_plan.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/query_plan.rs new file mode 100644 index 00000000000..d2a84a95b07 --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/query_plan.rs @@ -0,0 +1,266 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//! Backend query plan models. +//! +//! These types model the response from the Cosmos DB Gateway when issuing a +//! query plan request (`OperationType::QueryPlan`). The planner uses them to +//! determine partition targeting, detect unsupported query features, and build +//! the dataflow pipeline. + +use std::collections::HashMap; + +use serde::Deserialize; + +/// The response returned by the Gateway for a query plan request. +#[derive(Debug, Default, Deserialize)] +#[serde(rename_all = "camelCase")] +#[allow(dead_code)] // Wire-format fields; not all are consumed today. +pub(crate) struct QueryPlan { + /// The version of the query plan format. + pub partitioned_query_execution_info_version: usize, + + /// Detailed query information (ordering, aggregates, rewrites, etc.). + #[serde(default)] + pub query_info: Option, + + /// The EPK ranges that the query references. + /// + /// Used by the planner to limit which physical partitions get queried. + pub query_ranges: Vec, + + /// Information about hybrid search queries, if applicable. + pub hybrid_search_query_info: Option, +} + +/// Information about a hybrid search query. +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +#[allow(dead_code)] // Wire-format fields; hybrid search isn't fully wired yet. +pub(crate) struct HybridSearchQueryInfo { + /// The query used for global statistics gathering. + pub global_statistics_query: String, + + /// Individual component queries that make up the hybrid search. + pub component_query_infos: Vec, + + /// Weights assigned to each component query. + #[serde(default)] + pub component_weights: Vec, + + /// Number of results to skip. + pub skip: Option, + + /// Number of results to take (always present for hybrid search). + pub take: Option, + + /// Whether global statistics are required. + pub requires_global_statistics: bool, +} + +/// The kind of DISTINCT tracking required by the query. +#[derive(Debug, Deserialize, Default, PartialEq, Eq)] +pub(crate) enum DistinctType { + /// No deduplication required. + #[default] + None, + + /// Order-preserving deduplication. + Ordered, + + /// Order-independent deduplication. + Unordered, +} + +/// Detailed query plan information. +#[derive(Debug, Deserialize, Default)] +#[serde(default)] +#[serde(rename_all = "camelCase")] +pub(crate) struct QueryInfo { + /// The kind of DISTINCT clause, if any. + pub distinct_type: DistinctType, + + /// `TOP` clause limit. + pub top: Option, + + /// `OFFSET` clause value. + pub offset: Option, + + /// `LIMIT` clause value (from `OFFSET`/`LIMIT`). + pub limit: Option, + + /// Sort orders for `ORDER BY` expressions. + pub order_by: Vec, + + /// Expressions used by `ORDER BY` clauses. + pub order_by_expressions: Vec, + + /// Expressions used by `GROUP BY` clauses. + pub group_by_expressions: Vec, + + /// Aliases used by `GROUP BY` clauses. + pub group_by_aliases: Vec, + + /// Aggregates used in the `SELECT` portion of a `GROUP BY` query. + pub aggregates: Vec, + + /// Mapping from GROUP BY aliases to aggregate types. + pub group_by_alias_to_aggregate_type: HashMap, + + /// Rewritten form of the query for single-partition sub-queries. + /// + /// When non-empty, this should be used instead of the original query text + /// for individual partition requests. + pub rewritten_query: String, + + /// Whether the query contains a `SELECT VALUE` clause. + pub has_select_value: bool, + + /// Whether the query contains a non-streaming `ORDER BY`. + pub has_non_streaming_order_by: bool, +} + +/// Sort order for an `ORDER BY` expression. +#[derive(Debug, Deserialize, Clone, Copy, PartialEq, Eq)] +pub(crate) enum SortOrder { + Ascending, + Descending, +} + +/// An EPK range covered by the query. +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +#[allow(dead_code)] // Inclusivity flags are wire-format; planner treats ranges uniformly. +pub(crate) struct QueryRange { + /// The minimum EPK value. + pub min: String, + + /// The maximum EPK value. + pub max: String, + + /// Whether the minimum value is inclusive. + pub is_min_inclusive: bool, + + /// Whether the maximum value is inclusive. + pub is_max_inclusive: bool, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn deserializes_minimal_query_plan() { + let json = r#"{ + "partitionedQueryExecutionInfoVersion": 1, + "queryRanges": [ + { + "min": "", + "max": "FF", + "isMinInclusive": true, + "isMaxInclusive": false + } + ] + }"#; + let plan: QueryPlan = serde_json::from_str(json).unwrap(); + assert_eq!(plan.partitioned_query_execution_info_version, 1); + assert!(plan.query_info.is_none()); + assert!(plan.hybrid_search_query_info.is_none()); + assert_eq!(plan.query_ranges.len(), 1); + assert_eq!(plan.query_ranges[0].min, ""); + assert_eq!(plan.query_ranges[0].max, "FF"); + assert!(plan.query_ranges[0].is_min_inclusive); + assert!(!plan.query_ranges[0].is_max_inclusive); + } + + #[test] + fn deserializes_query_plan_with_order_by() { + let json = r#"{ + "partitionedQueryExecutionInfoVersion": 2, + "queryInfo": { + "orderBy": ["Ascending", "Descending"], + "orderByExpressions": ["c.name", "c.age"], + "rewrittenQuery": "SELECT c.name, c.age FROM c ORDER BY c.name ASC, c.age DESC" + }, + "queryRanges": [] + }"#; + let plan: QueryPlan = serde_json::from_str(json).unwrap(); + let info = plan.query_info.unwrap(); + assert_eq!( + info.order_by, + vec![SortOrder::Ascending, SortOrder::Descending] + ); + assert_eq!(info.order_by_expressions, vec!["c.name", "c.age"]); + } + + #[test] + fn deserializes_query_plan_with_top_and_aggregates() { + let json = r#"{ + "partitionedQueryExecutionInfoVersion": 1, + "queryInfo": { + "top": 10, + "aggregates": ["Count"], + "distinctType": "Ordered" + }, + "queryRanges": [] + }"#; + let plan: QueryPlan = serde_json::from_str(json).unwrap(); + let info = plan.query_info.unwrap(); + assert_eq!(info.top, Some(10)); + assert_eq!(info.aggregates, vec!["Count"]); + assert_eq!(info.distinct_type, DistinctType::Ordered); + } + + #[test] + fn deserializes_query_plan_with_hybrid_search() { + let json = r#"{ + "partitionedQueryExecutionInfoVersion": 1, + "queryRanges": [], + "hybridSearchQueryInfo": { + "globalStatisticsQuery": "SELECT COUNT(1) FROM c", + "componentQueryInfos": [], + "componentWeights": [0.5, 0.5], + "skip": null, + "take": 10, + "requiresGlobalStatistics": true + } + }"#; + let plan: QueryPlan = serde_json::from_str(json).unwrap(); + let hybrid = plan.hybrid_search_query_info.unwrap(); + assert_eq!(hybrid.global_statistics_query, "SELECT COUNT(1) FROM c"); + assert_eq!(hybrid.component_weights, vec![0.5, 0.5]); + assert_eq!(hybrid.take, Some(10)); + assert!(hybrid.requires_global_statistics); + } + + #[test] + fn deserializes_query_plan_with_offset_limit() { + let json = r#"{ + "partitionedQueryExecutionInfoVersion": 1, + "queryInfo": { + "offset": 5, + "limit": 20 + }, + "queryRanges": [] + }"#; + let plan: QueryPlan = serde_json::from_str(json).unwrap(); + let info = plan.query_info.unwrap(); + assert_eq!(info.offset, Some(5)); + assert_eq!(info.limit, Some(20)); + } + + #[test] + fn deserializes_multiple_query_ranges() { + let json = r#"{ + "partitionedQueryExecutionInfoVersion": 1, + "queryRanges": [ + { "min": "", "max": "40", "isMinInclusive": true, "isMaxInclusive": false }, + { "min": "80", "max": "FF", "isMinInclusive": true, "isMaxInclusive": false } + ] + }"#; + let plan: QueryPlan = serde_json::from_str(json).unwrap(); + assert_eq!(plan.query_ranges.len(), 2); + assert_eq!(plan.query_ranges[0].max, "40"); + assert_eq!(plan.query_ranges[1].min, "80"); + } +} diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/request.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/request.rs new file mode 100644 index 00000000000..fa019b047f5 --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/request.rs @@ -0,0 +1,788 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//! Request leaf node for the dataflow pipeline. + +use std::sync::Arc; + +use async_trait::async_trait; +use azure_core::http::StatusCode; + +use crate::models::{CosmosOperation, CosmosResponse, FeedRange, PartitionKey, SubStatusCode}; + +use super::{ + PageResult, PartitionRoutingRefresh, PipelineContext, PipelineNode, PipelineNodeState, + ResolvedRange, +}; + +/// The target of a request node. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum RequestTarget { + /// The request is to a non-partitioned resource (databases, containers, offers, etc.) + NonPartitioned, + + /// A single logical partition key. + LogicalPartitionKey(PartitionKey), + + /// A physical partition key range whose full EPK coverage is owned by this request. + PartitionKeyRange { + /// Full EPK range covered by the physical partition this request owns. + range: FeedRange, + /// Partition key range ID for the owned physical partition. + partition_key_range_id: String, + }, + + /// An EPK slice that must be queried inside a broader physical partition key range. + EffectivePartitionKeyRange { + /// EPK range scoped by this request. + range: FeedRange, + /// Partition key range ID containing `range`. + partition_key_range_id: String, + }, +} + +impl RequestTarget { + /// Returns the EPK slice owned by this request target, if any. + fn owned_range(&self) -> Option<&FeedRange> { + match self { + RequestTarget::PartitionKeyRange { range, .. } + | RequestTarget::EffectivePartitionKeyRange { range, .. } => Some(range), + _ => None, + } + } + + /// Returns `true` if this target's EPK range starts at the same point as `parent_range`. + fn covers_start_of(&self, parent_range: &FeedRange) -> bool { + self.owned_range() + .is_some_and(|range| range.min_inclusive() == parent_range.min_inclusive()) + } +} + +fn intersect_feed_ranges(left: &FeedRange, right: &FeedRange) -> Option { + let min = if left.min_inclusive() >= right.min_inclusive() { + left.min_inclusive().clone() + } else { + right.min_inclusive().clone() + }; + let max = if left.max_exclusive() <= right.max_exclusive() { + left.max_exclusive().clone() + } else { + right.max_exclusive().clone() + }; + + (min < max).then(|| FeedRange::new(min, max)) +} + +#[derive(Debug, PartialEq, Eq)] +enum RequestState { + /// No request has been sent yet. The next page will trigger the initial request. + Initial, + + /// A request has been sent and a server continuation token has been received, but not all pages have been drained yet. The next page will trigger a request with the continuation token. + Continuing { continuation: String }, + + /// All pages have been drained. No further requests will be sent. + Drained, +} + +/// Leaf node that executes one Cosmos DB request per page. +/// +/// The `operation` is held as an `Arc` so the same logical +/// operation can be shared across many `Request` nodes (e.g. in a fan-out +/// `SequentialDrain` over multiple partitions) without paying for one full +/// `CosmosOperation` copy per node. Per-request differences are applied at +/// execution time via [`OperationOverrides`](crate::pipeline::OperationOverrides), +/// not by mutating the shared operation. +pub(crate) struct Request { + operation: Arc, + target: RequestTarget, + state: RequestState, +} + +impl Request { + /// Creates a request node. + pub(crate) fn new( + operation: Arc, + target: RequestTarget, + initial_continuation: Option, + ) -> Self { + let initial_state = if let Some(token) = initial_continuation { + RequestState::Continuing { + continuation: token, + } + } else { + RequestState::Initial + }; + Self { + operation, + target, + state: initial_state, + } + } + + #[cfg(test)] + /// Returns the operation this request node executes. + pub(crate) fn operation(&self) -> &CosmosOperation { + &self.operation + } + + #[cfg(test)] + /// Returns the target this request node uses for routing. + pub(crate) fn target(&self) -> &RequestTarget { + &self.target + } +} + +#[async_trait] +impl PipelineNode for Request { + async fn next_page( + &mut self, + context: &mut PipelineContext<'_>, + ) -> azure_core::Result { + tracing::trace!( + target = ?self.target, + state = ?self.state, + "executing request node" + ); + + let continuation = match &self.state { + RequestState::Initial => None, + RequestState::Continuing { continuation } => Some(continuation.clone()), + RequestState::Drained => return Ok(PageResult::Drained), + }; + + match context + .execute_request( + &self.operation, + self.target.clone(), + PartitionRoutingRefresh::UseCached, + continuation.clone(), + ) + .await + { + Ok(response) => Ok(self.handle_response(response)), + Err(error) if is_partition_topology_change(&error) => { + self.handle_partition_topology_change(context, error, continuation) + .await + } + Err(error) => Err(error), + } + } + + #[cfg(test)] + fn into_children(self) -> Vec> { + Vec::new() + } + + fn snapshot_state(&self) -> PipelineNodeState { + match &self.state { + RequestState::Initial => PipelineNodeState::Request { + server_continuation: None, + }, + RequestState::Continuing { continuation } => PipelineNodeState::Request { + server_continuation: Some(continuation.clone()), + }, + RequestState::Drained => PipelineNodeState::Drained, + } + } + + fn feed_range(&self) -> Option<&FeedRange> { + self.target.owned_range() + } +} +impl Request { + fn handle_response(&mut self, response: CosmosResponse) -> PageResult { + let continuation = response.headers().continuation.clone(); + tracing::trace!( + target = ?self.target, + status = ?response.status(), + output_continuation = ?continuation, + "request completed" + ); + self.state = if let Some(token) = continuation { + RequestState::Continuing { + continuation: token, + } + } else { + RequestState::Drained + }; + tracing::trace!(target = ?self.target, state = ?self.state, "updated request state after response"); + let is_terminal = matches!(self.state, RequestState::Drained); + PageResult::Page { + response, + is_terminal, + } + } + + async fn handle_partition_topology_change( + &mut self, + context: &mut PipelineContext<'_>, + error: azure_core::Error, + continuation: Option, + ) -> azure_core::Result { + match &self.target { + RequestTarget::NonPartitioned => { + // Non-partitioned resources don't have partition topology changes. + Err(error) + } + RequestTarget::LogicalPartitionKey(_) => { + // This shouldn't really happen, but it's been observed. + // Since the original request had a logical partition key, + // the gateway should have been able to route the request + // to the correct partition even if it has split. + // But we can do a single retry without forcing a topology refresh to see if it succeeds. + context + .execute_request( + &self.operation, + self.target.clone(), + PartitionRoutingRefresh::ForceRefresh, + continuation, + ) + .await + .map(|response| { + tracing::trace!( + target = ?self.target, + status = ?response.status(), + "retry after logical partition key topology change succeeded" + ); + self.handle_response(response) + }) + } + RequestTarget::PartitionKeyRange { range, .. } + | RequestTarget::EffectivePartitionKeyRange { range, .. } => { + let range = range.clone(); + self.split_for_topology_change(context, &range).await + } + } + } + + /// Resolves the current topology for this node's EPK range and returns + /// a `SplitRequired` result with replacement nodes for each sub-range. + async fn split_for_topology_change( + &self, + context: &mut PipelineContext<'_>, + range: &FeedRange, + ) -> azure_core::Result { + let resolved = context + .resolve_ranges(range, PartitionRoutingRefresh::ForceRefresh) + .await?; + + let replacement_nodes: Vec> = resolved + .into_iter() + .map(|resolved_range| { + let ResolvedRange { + partition_key_range_id, + range: resolved_range, + } = resolved_range; + let owned_range = intersect_feed_ranges(&resolved_range, range).expect( + "topology provider must return ranges that overlap the request's owned EPK range", + ); + + let target = if owned_range == resolved_range { + RequestTarget::PartitionKeyRange { + range: resolved_range, + partition_key_range_id, + } + } else { + RequestTarget::EffectivePartitionKeyRange { + range: owned_range, + partition_key_range_id, + } + }; + // Carry over the server continuation to the first replacement that + // covers the same starting EPK. For a split, only the left-most child + // inherits the continuation since it resumes where this node left off. + let continuation = match (target.covers_start_of(range), &self.state) { + ( + true, + RequestState::Continuing { + continuation: latest_server_continuation, + }, + ) => Some(latest_server_continuation.clone()), + _ => None, + }; + Box::new(Request::new(self.operation.clone(), target, continuation)) + as Box + }) + .collect(); + + Ok(PageResult::SplitRequired { replacement_nodes }) + } +} + +// Partition topology changes are a specific subset of `Gone` substatus codes. +// Other substatus mappings live in `pipeline::retry_evaluation`; this one stays +// here because it drives pipeline-level repair (splitting a node into +// replacements) rather than per-attempt retry. +fn is_partition_topology_change(error: &azure_core::Error) -> bool { + match error.kind() { + azure_core::error::ErrorKind::HttpResponse { + status, error_code, .. + } if *status == StatusCode::Gone => error_code + .as_deref() + .and_then(|code| code.parse::().ok()) + .is_some_and(is_partition_topology_change_substatus), + _ => false, + } +} + +fn is_partition_topology_change_substatus(substatus: u32) -> bool { + matches!( + SubStatusCode::new(substatus), + SubStatusCode::PARTITION_KEY_RANGE_GONE + | SubStatusCode::COMPLETING_SPLIT + | SubStatusCode::COMPLETING_PARTITION_MIGRATION + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::driver::dataflow::{mocks::*, RequestExecutor, ResolvedRange, TopologyProvider}; + use crate::models::{effective_partition_key::EffectivePartitionKey, FeedRange}; + + #[derive(Clone, Debug)] + struct PhysicalPartitionSpec { + partition_key_range_id: String, + range: FeedRange, + } + + #[derive(Clone, Debug, PartialEq, Eq)] + struct RequestSpec { + target: RequestTarget, + continuation: Option, + } + + struct ScenarioTopologyProvider { + resolved_ranges: Vec, + } + + impl ScenarioTopologyProvider { + fn new(partitions: &[PhysicalPartitionSpec]) -> Self { + Self { + resolved_ranges: partitions + .iter() + .map(|partition| ResolvedRange { + partition_key_range_id: partition.partition_key_range_id.clone(), + range: partition.range.clone(), + }) + .collect(), + } + } + } + + impl TopologyProvider for ScenarioTopologyProvider { + fn resolve_ranges<'a>( + &'a mut self, + range: &'a FeedRange, + _refresh: PartitionRoutingRefresh, + ) -> futures::future::BoxFuture<'a, azure_core::Result>> { + let resolved = self + .resolved_ranges + .iter() + .filter(|candidate| { + candidate.range.min_inclusive() < range.max_exclusive() + && candidate.range.max_exclusive() > range.min_inclusive() + }) + .cloned() + .collect::>(); + + Box::pin(async move { + if resolved.is_empty() { + Err(azure_core::Error::with_message( + azure_core::error::ErrorKind::Other, + "scenario topology produced no overlapping ranges", + )) + } else { + Ok(resolved) + } + }) + } + } + + struct AlwaysGoneRequestExecutor; + + impl RequestExecutor for AlwaysGoneRequestExecutor { + fn execute_request<'a>( + &'a mut self, + _operation: &'a CosmosOperation, + _target: RequestTarget, + _partition_routing_refresh: PartitionRoutingRefresh, + _continuation: Option, + ) -> futures::future::BoxFuture<'a, azure_core::Result> { + Box::pin(async { Err(gone_error()) }) + } + } + + fn partition_key_range_target( + min: &str, + max: &str, + partition_key_range_id: &str, + ) -> RequestTarget { + RequestTarget::PartitionKeyRange { + range: FeedRange::new( + EffectivePartitionKey::from(min), + EffectivePartitionKey::from(max), + ), + partition_key_range_id: partition_key_range_id.to_string(), + } + } + + fn physical_partition( + min: &str, + max: &str, + partition_key_range_id: &str, + ) -> PhysicalPartitionSpec { + PhysicalPartitionSpec { + partition_key_range_id: partition_key_range_id.to_string(), + range: FeedRange::new( + EffectivePartitionKey::from(min), + EffectivePartitionKey::from(max), + ), + } + } + + fn request_spec(target: RequestTarget, continuation: Option<&str>) -> RequestSpec { + RequestSpec { + target, + continuation: continuation.map(str::to_owned), + } + } + + fn partition_key_request( + min: &str, + max: &str, + partition_key_range_id: &str, + continuation: Option<&str>, + ) -> RequestSpec { + request_spec( + partition_key_range_target(min, max, partition_key_range_id), + continuation, + ) + } + + fn effective_partition_key_request( + min: &str, + max: &str, + partition_key_range_id: &str, + continuation: Option<&str>, + ) -> RequestSpec { + request_spec( + effective_partition_key_range_target(min, max, partition_key_range_id), + continuation, + ) + } + + fn build_request(spec: RequestSpec) -> Request { + Request::new(Arc::new(operation()), spec.target, spec.continuation) + } + + fn snapshot_request(request: &Request) -> RequestSpec { + let continuation = match &request.state { + RequestState::Initial => None, + RequestState::Continuing { continuation } => Some(continuation.clone()), + RequestState::Drained => panic!("scenario helper should not produce drained requests"), + }; + + RequestSpec { + target: request.target.clone(), + continuation, + } + } + + async fn apply_topology_round( + requests: Vec, + partitions: &[PhysicalPartitionSpec], + ) -> Vec { + let mut executor = AlwaysGoneRequestExecutor; + let mut topology = ScenarioTopologyProvider::new(partitions); + let mut rewritten = Vec::new(); + + for mut request in requests { + let mut context = PipelineContext::new(&mut executor, Some(&mut topology)); + match request.next_page(&mut context).await.unwrap() { + PageResult::SplitRequired { replacement_nodes } => { + rewritten.extend(replacement_nodes.into_iter().map(|node| { + *node + .downcast::() + .expect("scenario helper should only produce request nodes") + })); + } + other => panic!("expected SplitRequired during topology rewrite, got {other:?}"), + } + } + + rewritten + } + + async fn assert_topology_rewrite( + initial_requests: Vec, + topology_rounds: Vec>, + expected_requests: Vec, + ) { + let mut current = initial_requests + .into_iter() + .map(build_request) + .collect::>(); + + // Each round applies a new physical partition layout to the current request list. + // We intentionally do not try to coalesce adjacent requests after repeated topology + // changes; these tests care about correctness of ownership, not optimality. + for partitions in topology_rounds { + current = apply_topology_round(current, &partitions).await; + } + + let actual = current.iter().map(snapshot_request).collect::>(); + assert_eq!(actual, expected_requests); + } + + fn effective_partition_key_range_target( + min: &str, + max: &str, + partition_key_range_id: &str, + ) -> RequestTarget { + RequestTarget::EffectivePartitionKeyRange { + range: FeedRange::new( + EffectivePartitionKey::from(min), + EffectivePartitionKey::from(max), + ), + partition_key_range_id: partition_key_range_id.to_string(), + } + } + + #[tokio::test] + async fn request_retries_logical_partition_key_topology_change_once() { + let mut request = Request::new(Arc::new(operation()), logical_partition_target(), None); + let mut executor = MockRequestExecutor::new(vec![Err(gone_error()), Ok(response(b"ok"))]); + let mut topology = NoopTopologyProvider; + let mut context = PipelineContext::new(&mut executor, Some(&mut topology)); + + let page = unwrap_page(request.next_page(&mut context).await); + + assert_eq!(page.body(), b"ok"); + assert_eq!( + executor.refresh_calls, + vec![ + PartitionRoutingRefresh::UseCached, + PartitionRoutingRefresh::ForceRefresh + ] + ); + assert_eq!(executor.continuation_calls, vec![None, None]); + } + + #[tokio::test] + async fn request_returns_second_logical_partition_key_topology_change() { + let mut request = Request::new(Arc::new(operation()), logical_partition_target(), None); + let mut executor = MockRequestExecutor::new(vec![Err(gone_error()), Err(gone_error())]); + let mut topology = NoopTopologyProvider; + let mut context = PipelineContext::new(&mut executor, Some(&mut topology)); + + let error = request.next_page(&mut context).await.unwrap_err(); + + assert!(is_partition_topology_change(&error)); + assert_eq!( + executor.refresh_calls, + vec![ + PartitionRoutingRefresh::UseCached, + PartitionRoutingRefresh::ForceRefresh + ] + ); + assert_eq!(executor.continuation_calls, vec![None, None]); + } + + #[tokio::test] + async fn request_does_not_retry_non_topology_gone() { + let mut request = Request::new(Arc::new(operation()), logical_partition_target(), None); + let mut executor = MockRequestExecutor::new(vec![Err(non_topology_gone_error())]); + let mut topology = NoopTopologyProvider; + let mut context = PipelineContext::new(&mut executor, Some(&mut topology)); + + let error = request.next_page(&mut context).await.unwrap_err(); + + assert!(!is_partition_topology_change(&error)); + assert_eq!( + executor.refresh_calls, + vec![PartitionRoutingRefresh::UseCached] + ); + assert_eq!(executor.continuation_calls, vec![None]); + } + + #[tokio::test] + async fn request_tracks_server_continuation_for_next_page() { + let mut request = Request::new(Arc::new(operation()), logical_partition_target(), None); + let mut executor = MockRequestExecutor::new(vec![ + Ok(response_with_continuation(b"page1", Some("token-1"))), + Ok(response_with_continuation(b"page2", Some("token-2"))), + ]); + let mut topology = NoopTopologyProvider; + let mut context = PipelineContext::new(&mut executor, Some(&mut topology)); + + let page1 = unwrap_page(request.next_page(&mut context).await); + let page2 = unwrap_page(request.next_page(&mut context).await); + + assert_eq!(page1.body(), b"page1"); + assert_eq!(page2.body(), b"page2"); + assert_eq!( + executor.continuation_calls, + vec![None, Some("token-1".to_string())] + ); + assert_eq!( + request.state, + RequestState::Continuing { + continuation: "token-2".to_string() + } + ); + } + + #[tokio::test] + async fn request_uses_restored_continuation_on_first_page() { + let mut request = Request::new( + Arc::new(operation()), + logical_partition_target(), + Some("restored-token".to_string()), + ); + let mut executor = MockRequestExecutor::new(vec![Ok(response(b"page"))]); + let mut topology = NoopTopologyProvider; + let mut context = PipelineContext::new(&mut executor, Some(&mut topology)); + + let page = unwrap_page(request.next_page(&mut context).await); + + assert_eq!(page.body(), b"page"); + assert_eq!( + executor.continuation_calls, + vec![Some("restored-token".to_string())] + ); + assert_eq!(request.state, RequestState::Drained); + } + + // ── Topology rewrite scenarios ─────────────────────────────────────── + + #[tokio::test] + async fn topology_rewrite_handles_simple_split() { + assert_topology_rewrite( + vec![partition_key_request("", "80", "0", Some("server-token"))], + vec![vec![ + physical_partition("", "40", "1"), + physical_partition("40", "80", "2"), + ]], + vec![ + partition_key_request("", "40", "1", Some("server-token")), + partition_key_request("40", "80", "2", None), + ], + ) + .await; + } + + #[tokio::test] + async fn topology_rewrite_handles_simple_merge() { + assert_topology_rewrite( + vec![ + partition_key_request("", "40", "left", Some("merge-token")), + partition_key_request("40", "80", "right", None), + ], + vec![vec![physical_partition("", "80", "merged")]], + vec![ + effective_partition_key_request("", "40", "merged", Some("merge-token")), + effective_partition_key_request("40", "80", "merged", None), + ], + ) + .await; + } + + #[tokio::test] + async fn topology_rewrite_leaves_unchanged_neighbors_alone() { + assert_topology_rewrite( + vec![ + partition_key_request("", "40", "left", Some("ct")), + partition_key_request("40", "80", "right", None), + ], + vec![vec![ + physical_partition("", "40", "left"), + physical_partition("40", "60", "right-a"), + physical_partition("60", "80", "right-b"), + ]], + vec![ + partition_key_request("", "40", "left", Some("ct")), + partition_key_request("40", "60", "right-a", None), + partition_key_request("60", "80", "right-b", None), + ], + ) + .await; + } + + #[tokio::test] + async fn topology_rewrite_can_return_from_merged_epk_slices_to_exact_pk_ranges() { + assert_topology_rewrite( + vec![ + effective_partition_key_request("", "40", "merged", Some("ct")), + effective_partition_key_request("40", "80", "merged", None), + ], + vec![vec![ + physical_partition("", "40", "left"), + physical_partition("40", "80", "right"), + ]], + vec![ + partition_key_request("", "40", "left", Some("ct")), + partition_key_request("40", "80", "right", None), + ], + ) + .await; + } + + #[tokio::test] + async fn topology_rewrite_handles_merge_then_different_split_mid_pipeline() { + assert_topology_rewrite( + vec![ + partition_key_request("00", "20", "a", Some("ct")), + partition_key_request("20", "40", "b", None), + partition_key_request("40", "80", "c", None), + ], + vec![ + vec![ + physical_partition("00", "40", "merged-left"), + physical_partition("40", "80", "c"), + ], + vec![ + physical_partition("00", "10", "split-a"), + physical_partition("10", "30", "split-b"), + physical_partition("30", "50", "split-c"), + physical_partition("50", "80", "split-d"), + ], + ], + vec![ + partition_key_request("00", "10", "split-a", Some("ct")), + effective_partition_key_request("10", "20", "split-b", None), + effective_partition_key_request("20", "30", "split-b", None), + effective_partition_key_request("30", "40", "split-c", None), + effective_partition_key_request("40", "50", "split-c", None), + partition_key_request("50", "80", "split-d", None), + ], + ) + .await; + } + + #[tokio::test] + async fn topology_provider_error_propagates() { + let mut request = Request::new(Arc::new(operation()), epk_range_target(), None); + let mut executor = MockRequestExecutor::new(vec![Err(gone_error())]); + let mut topology = MockTopologyProvider::new(vec![Err(azure_core::Error::with_message( + azure_core::error::ErrorKind::Other, + "topology fetch failed", + ))]); + let mut context = PipelineContext::new(&mut executor, Some(&mut topology)); + + let err = request.next_page(&mut context).await.unwrap_err(); + assert_eq!(err.to_string(), "topology fetch failed"); + } + + #[tokio::test] + async fn non_partitioned_topology_change_not_retried() { + let mut request = Request::new(Arc::new(operation()), RequestTarget::NonPartitioned, None); + let mut executor = MockRequestExecutor::new(vec![Err(gone_error())]); + let mut topology = NoopTopologyProvider; + let mut context = PipelineContext::new(&mut executor, Some(&mut topology)); + + let err = request.next_page(&mut context).await.unwrap_err(); + assert!(is_partition_topology_change(&err)); + } +} diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/snapshot.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/snapshot.rs new file mode 100644 index 00000000000..15fdc13739f --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/snapshot.rs @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//! Pipeline node snapshot state used to serialize / deserialize continuation +//! tokens. +//! +//! Each variant captures only the information required to reconstruct an +//! equivalent pipeline on resume. In particular, [`SequentialDrain`] only +//! preserves its left-most child plus an EPK floor; the planner reconstructs +//! the remaining (yet-to-drain) children from the operation's query ranges +//! and the current topology. + +use serde::{Deserialize, Serialize}; + +/// Serializable snapshot of a [`PipelineNode`](super::PipelineNode) subtree. +/// +/// The shape is intentionally open to future intermediate node kinds so a +/// parent does not need to know what type its child is — every node produces +/// a `PipelineNodeState` from `snapshot_state()`. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(tag = "kind", rename_all = "snake_case")] +pub(crate) enum PipelineNodeState { + /// The node has produced all of its pages. + Drained, + + /// A leaf request node. + /// + /// `server_continuation` is the opaque page token returned by the server + /// for the next page, or `None` when no request has yet been issued. + Request { + #[serde(default, skip_serializing_if = "Option::is_none")] + server_continuation: Option, + }, + + /// A sequential drain over EPK-ordered children. + /// + /// Only the left-most (currently-active) child's snapshot is preserved. + /// `current_min_epk` is the minimum EPK still left to drain; the planner + /// uses it to skip ranges that are entirely below the cursor on resume. + SequentialDrain { + current_min_epk: String, + left_most: Box, + }, +} diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/topology.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/topology.rs new file mode 100644 index 00000000000..c40cae64e4f --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/dataflow/topology.rs @@ -0,0 +1,286 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//! Topology provider adapter backed by the partition key range cache. + +use futures::future::BoxFuture; + +use crate::{ + driver::cache::{PartitionKeyRangeCache, PkRangeFetchResult}, + models::{ContainerReference, FeedRange}, +}; + +use super::{PartitionRoutingRefresh, ResolvedRange, TopologyProvider}; + +/// Adapts [`PartitionKeyRangeCache`] to the [`TopologyProvider`] trait. +/// +/// Holds a reference to the cache, the container being queried, and a function +/// that fetches partition key ranges from the service. On each +/// [`resolve_ranges`](TopologyProvider::resolve_ranges) call, it uses the +/// provided [`PartitionRoutingRefresh`](super::PartitionRoutingRefresh) to +/// decide whether to refresh the cache first. +/// +/// # Type parameters +/// +/// * `F` — `Fn(ContainerReference, Option) -> Fut` that fetches +/// pk-ranges from the service. Passed by reference to the cache so the +/// adapter can call it repeatedly without requiring `Clone`. +pub(crate) struct CachedTopologyProvider<'a, F> { + cache: &'a PartitionKeyRangeCache, + container: ContainerReference, + fetch_pk_ranges: F, +} + +impl<'a, F> CachedTopologyProvider<'a, F> { + /// Creates a topology provider backed by the partition key range cache. + pub(crate) fn new( + cache: &'a PartitionKeyRangeCache, + container: ContainerReference, + fetch_pk_ranges: F, + ) -> Self { + Self { + cache, + container, + fetch_pk_ranges, + } + } +} + +impl TopologyProvider for CachedTopologyProvider<'_, F> +where + F: Fn(ContainerReference, Option) -> Fut + Send + Sync, + Fut: std::future::Future> + Send, +{ + fn resolve_ranges<'a>( + &'a mut self, + range: &'a FeedRange, + refresh: PartitionRoutingRefresh, + ) -> BoxFuture<'a, azure_core::Result>> { + let force_refresh = matches!(refresh, PartitionRoutingRefresh::ForceRefresh); + Box::pin(async move { + let pk_ranges = self + .cache + .resolve_overlapping_ranges( + &self.container, + range.min_inclusive()..range.max_exclusive(), + force_refresh, + &self.fetch_pk_ranges, + ) + .await; + + let pk_ranges = match pk_ranges { + Some(ranges) if !ranges.is_empty() => ranges, + _ => { + return Err(azure_core::Error::with_message( + azure_core::error::ErrorKind::Other, + "failed to resolve partition key ranges from topology cache", + )); + } + }; + + Ok(pk_ranges + .into_iter() + .map(|pkr| ResolvedRange { + partition_key_range_id: pkr.id, + range: FeedRange::new(pkr.min_inclusive, pkr.max_exclusive), + }) + .collect()) + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::models::{ + effective_partition_key::EffectivePartitionKey, + partition_key_range::PartitionKeyRange as PkRange, ContainerProperties, + }; + + fn make_container() -> ContainerReference { + let account = crate::models::AccountReference::with_master_key( + url::Url::parse("https://test.documents.azure.com:443/").unwrap(), + "dGVzdA==", + ); + let props = ContainerProperties { + id: "c".into(), + partition_key: serde_json::from_str(r#"{"paths":["/pk"],"version":2}"#).unwrap(), + system_properties: Default::default(), + }; + ContainerReference::new(account, "db", "db_rid", "c", "c_rid", &props) + } + + async fn single_range_fetch( + _container: ContainerReference, + continuation: Option, + ) -> Option { + if continuation.is_some() { + Some(PkRangeFetchResult { + ranges: vec![], + continuation, + not_modified: true, + }) + } else { + Some(PkRangeFetchResult { + ranges: vec![PkRange::new("0".into(), "", "FF")], + continuation: Some("etag-1".to_string()), + not_modified: false, + }) + } + } + + async fn two_range_fetch( + _container: ContainerReference, + continuation: Option, + ) -> Option { + if continuation.is_some() { + Some(PkRangeFetchResult { + ranges: vec![], + continuation, + not_modified: true, + }) + } else { + Some(PkRangeFetchResult { + ranges: vec![ + PkRange::new("1".into(), "", "80"), + PkRange::new("2".into(), "80", "FF"), + ], + continuation: Some("etag-2".to_string()), + not_modified: false, + }) + } + } + + async fn three_range_fetch( + _container: ContainerReference, + continuation: Option, + ) -> Option { + if continuation.is_some() { + Some(PkRangeFetchResult { + ranges: vec![], + continuation, + not_modified: true, + }) + } else { + Some(PkRangeFetchResult { + ranges: vec![ + PkRange::new("1".into(), "", "40"), + PkRange::new("2".into(), "40", "80"), + PkRange::new("3".into(), "80", "FF"), + ], + continuation: Some("etag-3".to_string()), + not_modified: false, + }) + } + } + + async fn failing_fetch( + _container: ContainerReference, + _continuation: Option, + ) -> Option { + None + } + + #[tokio::test] + async fn resolves_single_range_for_full_epk_space() { + let cache = PartitionKeyRangeCache::new(); + let mut provider = + CachedTopologyProvider::new(&cache, make_container(), single_range_fetch); + + let ranges = provider + .resolve_ranges(&FeedRange::full(), PartitionRoutingRefresh::ForceRefresh) + .await + .unwrap(); + + assert_eq!(ranges.len(), 1); + assert_eq!(ranges[0].partition_key_range_id, "0"); + assert_eq!( + ranges[0].range.min_inclusive(), + &EffectivePartitionKey::min() + ); + assert_eq!( + ranges[0].range.max_exclusive(), + &EffectivePartitionKey::max() + ); + } + + #[tokio::test] + async fn resolves_split_ranges() { + let cache = PartitionKeyRangeCache::new(); + let mut provider = CachedTopologyProvider::new(&cache, make_container(), two_range_fetch); + + let ranges = provider + .resolve_ranges(&FeedRange::full(), PartitionRoutingRefresh::ForceRefresh) + .await + .unwrap(); + + assert_eq!(ranges.len(), 2); + assert_eq!(ranges[0].partition_key_range_id, "1"); + assert_eq!( + ranges[0].range.min_inclusive(), + &EffectivePartitionKey::min() + ); + assert_eq!( + ranges[0].range.max_exclusive(), + &EffectivePartitionKey::from("80") + ); + assert_eq!(ranges[1].partition_key_range_id, "2"); + assert_eq!( + ranges[1].range.min_inclusive(), + &EffectivePartitionKey::from("80") + ); + assert_eq!( + ranges[1].range.max_exclusive(), + &EffectivePartitionKey::max() + ); + } + + #[tokio::test] + async fn resolves_partial_epk_range() { + let cache = PartitionKeyRangeCache::new(); + let mut provider = CachedTopologyProvider::new(&cache, make_container(), two_range_fetch); + + let left_half = FeedRange::new( + EffectivePartitionKey::min(), + EffectivePartitionKey::from("80"), + ); + let ranges = provider + .resolve_ranges(&left_half, PartitionRoutingRefresh::ForceRefresh) + .await + .unwrap(); + + assert_eq!(ranges.len(), 1); + assert_eq!(ranges[0].partition_key_range_id, "1"); + } + + #[tokio::test] + async fn resolves_three_way_split() { + let cache = PartitionKeyRangeCache::new(); + let mut provider = CachedTopologyProvider::new(&cache, make_container(), three_range_fetch); + + let ranges = provider + .resolve_ranges(&FeedRange::full(), PartitionRoutingRefresh::ForceRefresh) + .await + .unwrap(); + + assert_eq!(ranges.len(), 3); + assert_eq!(ranges[0].partition_key_range_id, "1"); + assert_eq!(ranges[1].partition_key_range_id, "2"); + assert_eq!(ranges[2].partition_key_range_id, "3"); + } + + #[tokio::test] + async fn returns_error_when_fetch_fails() { + let cache = PartitionKeyRangeCache::new(); + let mut provider = CachedTopologyProvider::new(&cache, make_container(), failing_fetch); + + let err = provider + .resolve_ranges(&FeedRange::full(), PartitionRoutingRefresh::ForceRefresh) + .await + .unwrap_err(); + assert_eq!( + err.to_string(), + "failed to resolve partition key ranges from topology cache" + ); + } +} diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/mod.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/mod.rs index 2e7bdf123ab..bc899604699 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/driver/mod.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/mod.rs @@ -13,6 +13,7 @@ pub(crate) mod cache; mod cosmos_driver; +pub(crate) mod dataflow; pub(crate) mod jitter; pub(crate) mod pipeline; pub(crate) mod routing; @@ -20,6 +21,7 @@ mod runtime; pub(crate) mod transport; pub use cosmos_driver::CosmosDriver; +pub use dataflow::OperationPlan; pub use runtime::{CosmosDriverRuntime, CosmosDriverRuntimeBuilder}; /// Walks an error's `.source()` chain and joins all distinct messages into a diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/pipeline/components.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/pipeline/components.rs index ef0ccfcafcd..41161f17508 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/driver/pipeline/components.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/pipeline/components.rs @@ -402,6 +402,11 @@ impl TransportResult { _ => None, } } + + /// Returns true if this attempt resulted in a successful HTTP response (2xx). + pub fn is_successful(&self) -> bool { + matches!(self.outcome, TransportOutcome::Success { .. }) + } } /// The outcome of a single transport attempt. @@ -450,9 +455,14 @@ impl std::fmt::Display for TransportOutcome { impl std::fmt::Debug for TransportOutcome { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - TransportOutcome::Success { status, .. } => f + TransportOutcome::Success { + status, + cosmos_headers, + .. + } => f .debug_struct("Success") .field("status", status) + .field("cosmos_headers", &cosmos_headers) .field("body", &"...") .finish(), TransportOutcome::HttpError { diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/pipeline/operation_pipeline.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/pipeline/operation_pipeline.rs index 6ed076b89f8..335967c803c 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/driver/pipeline/operation_pipeline.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/pipeline/operation_pipeline.rs @@ -22,8 +22,9 @@ use crate::{ }, driver::transport::CosmosTransport, models::{ - request_header_names, AccountEndpoint, ActivityId, CosmosOperation, CosmosResponse, - Credential, DefaultConsistencyLevel, OperationType, SessionToken, SubStatusCode, + cosmos_headers::QUERY_CONTENT_TYPE, request_header_names, AccountEndpoint, ActivityId, + CosmosOperation, CosmosResponse, Credential, DefaultConsistencyLevel, + EffectivePartitionKey, OperationType, SessionToken, SubStatusCode, }, options::{ OperationOptionsView, ReadConsistencyStrategy, Region, ThroughputControlGroupSnapshot, @@ -45,6 +46,76 @@ use crate::driver::transport::{ AuthorizationContext, }; +/// Per-request overrides that take precedence over values from [`CosmosOperation`]. +/// +/// Used by the dataflow pipeline to inject routing and pagination state that +/// varies per physical partition or per page, without mutating the shared +/// `CosmosOperation`. Each field, when `Some`, emits the corresponding request +/// header in [`OperationOverrides::apply_headers`]. +#[derive(Debug, Clone, Default)] +pub(crate) struct OperationOverrides { + /// Feed range to constrain the request to (emits `x-ms-start-epk` / `x-ms-end-epk`). + pub feed_range: Option, + + /// Physical partition key range ID (emits `x-ms-documentdb-partitionkeyrangeid`). + pub partition_key_range_id: Option, + + /// Logical partition key (emits `x-ms-documentdb-partitionkey`). + pub partition_key: Option, + + /// Continuation token for pagination (emits `x-ms-continuation`). + pub continuation: Option, +} + +impl OperationOverrides { + /// Applies the override headers to the given header map. + /// + /// Headers set here take precedence over any previously-set values for + /// the same header name (they overwrite on conflict). + pub fn apply_headers( + &self, + headers: &mut azure_core::http::headers::Headers, + ) -> azure_core::Result<()> { + if let Some(feed_range) = &self.feed_range { + if feed_range.min_inclusive() != &EffectivePartitionKey::min() { + headers.insert( + HeaderName::from_static(request_header_names::START_EPK), + HeaderValue::from(feed_range.min_inclusive().as_str().to_owned()), + ); + } + if feed_range.max_exclusive() != &EffectivePartitionKey::max() { + headers.insert( + HeaderName::from_static(request_header_names::END_EPK), + HeaderValue::from(feed_range.max_exclusive().as_str().to_owned()), + ); + } + } + + if let Some(pk_range_id) = &self.partition_key_range_id { + headers.insert( + HeaderName::from_static(request_header_names::PARTITION_KEY_RANGE_ID), + HeaderValue::from(pk_range_id.clone()), + ); + } + + if let Some(pk) = &self.partition_key { + let pk_headers = pk.as_headers()?; + for (name, value) in pk_headers { + headers.insert(name, value); + } + } + + if let Some(continuation) = &self.continuation { + headers.insert( + HeaderName::from_static(request_header_names::CONTINUATION), + HeaderValue::from(continuation.clone()), + ); + } + + Ok(()) + } +} + /// Executes a Cosmos DB operation through the new pipeline architecture. /// /// This is the entry point called by `CosmosDriver::execute_operation`. @@ -56,6 +127,7 @@ use crate::driver::transport::{ #[allow(clippy::too_many_arguments)] pub(crate) async fn execute_operation_pipeline( operation: &CosmosOperation, + overrides: OperationOverrides, options: &OperationOptionsView<'_>, custom_headers: Option<&std::collections::HashMap>, location_state_store: &LocationStateStore, @@ -183,7 +255,8 @@ pub(crate) async fn execute_operation_pipeline( .flatten(), throughput_control, }; - let mut transport_request = build_transport_request(operation, custom_headers, &ctx)?; + let mut transport_request = + build_transport_request(operation, &overrides, custom_headers, &ctx)?; apply_optional_request_headers(&mut transport_request, operation, options); @@ -702,8 +775,11 @@ struct TransportRequestContext<'a> { /// Builds a `TransportRequest` from the operation and routing decision. /// /// If `resolved_session_token` is provided, it is added to the request headers. +/// Override headers from `overrides` are applied after operation headers, so they +/// take precedence. fn build_transport_request( operation: &CosmosOperation, + overrides: &OperationOverrides, custom_headers: Option<&std::collections::HashMap>, ctx: &TransportRequestContext<'_>, ) -> azure_core::Result { @@ -748,38 +824,53 @@ fn build_transport_request( ); } - // Add partition key headers - if let Some(pk) = operation.partition_key() { - let pk_headers = pk.as_headers()?; - for (name, value) in pk_headers { - headers.insert(name, value); + // Apply operation type-specific headers. + match operation.operation_type() { + OperationType::Upsert => { + headers.insert( + HeaderName::from_static(request_header_names::IS_UPSERT), + HeaderValue::from_static("true"), + ); } - } - - // Cosmos DB uses POST for both create and upsert; the service - // distinguishes them via this header. - if operation.operation_type() == OperationType::Upsert { - headers.insert( - HeaderName::from_static(request_header_names::IS_UPSERT), - HeaderValue::from_static("true"), - ); - } - - // Cosmos DB uses POST for batch (same endpoint as create/upsert); - // the service requires these headers to process the request as a batch. - if operation.operation_type() == OperationType::Batch { - headers.insert( - HeaderName::from_static(request_header_names::IS_BATCH_REQUEST), - HeaderValue::from_static("True"), - ); - headers.insert( - HeaderName::from_static(request_header_names::BATCH_ATOMIC), - HeaderValue::from_static("True"), - ); - headers.insert( - HeaderName::from_static(request_header_names::BATCH_CONTINUE_ON_ERROR), - HeaderValue::from_static("False"), - ); + OperationType::Batch => { + headers.insert( + HeaderName::from_static(request_header_names::IS_BATCH_REQUEST), + HeaderValue::from_static("True"), + ); + headers.insert( + HeaderName::from_static(request_header_names::BATCH_ATOMIC), + HeaderValue::from_static("True"), + ); + headers.insert( + HeaderName::from_static(request_header_names::BATCH_CONTINUE_ON_ERROR), + HeaderValue::from_static("False"), + ); + } + OperationType::Query | OperationType::SqlQuery => { + headers.insert( + HeaderName::from_static(request_header_names::IS_QUERY), + HeaderValue::from_static("True"), + ); + headers.insert( + azure_core::http::headers::CONTENT_TYPE, + HeaderValue::from_static(QUERY_CONTENT_TYPE), + ); + } + OperationType::QueryPlan => { + headers.insert( + HeaderName::from_static(request_header_names::IS_QUERY), + HeaderValue::from_static("True"), + ); + headers.insert( + azure_core::http::headers::CONTENT_TYPE, + HeaderValue::from_static(QUERY_CONTENT_TYPE), + ); + headers.insert( + HeaderName::from_static(request_header_names::IS_QUERY_PLAN_REQUEST), + HeaderValue::from_static("True"), + ); + } + _ => {} } // Add operation type header for fault injection rule matching @@ -796,6 +887,10 @@ fn build_transport_request( } } + // Apply overrides — these take precedence over operation-level headers + // (e.g., an override partition key replaces the operation's partition key). + overrides.apply_headers(&mut headers)?; + // Add resolved session token if let Some(token) = &ctx.resolved_session_token { headers.insert( @@ -1093,6 +1188,7 @@ mod tests { use url::Url; use super::build_transport_request; + use super::OperationOverrides; use super::TransportRequestContext; use crate::{ diagnostics::ExecutionContext, @@ -1166,7 +1262,8 @@ mod tests { throughput_control: None, }; let request = - build_transport_request(&operation, None, &ctx).expect("request should build"); + build_transport_request(&operation, &OperationOverrides::default(), None, &ctx) + .expect("request should build"); assert_eq!(request.url.path(), "/dbs"); } @@ -1187,7 +1284,8 @@ mod tests { throughput_control: None, }; let request = - build_transport_request(&operation, None, &ctx).expect("request should build"); + build_transport_request(&operation, &OperationOverrides::default(), None, &ctx) + .expect("request should build"); assert_eq!(request.url.path(), "/dbs/mydb"); } @@ -1208,7 +1306,8 @@ mod tests { throughput_control: None, }; let request = - build_transport_request(&operation, None, &ctx).expect("request should build"); + build_transport_request(&operation, &OperationOverrides::default(), None, &ctx) + .expect("request should build"); let activity_header = request .headers @@ -1233,8 +1332,12 @@ mod tests { resolved_session_token: None, throughput_control: None, }; - let request = - build_transport_request(&operation, None, &ctx).expect("request should build"); + let overrides = OperationOverrides { + partition_key: Some(PartitionKey::from("pk1")), + ..Default::default() + }; + let request = build_transport_request(&operation, &overrides, None, &ctx) + .expect("request should build"); let partition_key_header = request .headers @@ -1267,7 +1370,8 @@ mod tests { throughput_control: None, }; let request = - build_transport_request(&operation, None, &ctx).expect("request should build"); + build_transport_request(&operation, &OperationOverrides::default(), None, &ctx) + .expect("request should build"); assert_eq!( request.url.as_str(), @@ -1297,7 +1401,8 @@ mod tests { throughput_control: None, }; let request = - build_transport_request(&operation, None, &ctx).expect("request should build"); + build_transport_request(&operation, &OperationOverrides::default(), None, &ctx) + .expect("request should build"); assert_eq!( request.url.as_str(), @@ -2436,7 +2541,8 @@ mod tests { throughput_control: None, }; let request = - build_transport_request(&operation, None, &ctx).expect("request should build"); + build_transport_request(&operation, &OperationOverrides::default(), None, &ctx) + .expect("request should build"); let is_upsert = request .headers @@ -2468,7 +2574,8 @@ mod tests { throughput_control: None, }; let request = - build_transport_request(&operation, None, &ctx).expect("request should build"); + build_transport_request(&operation, &OperationOverrides::default(), None, &ctx) + .expect("request should build"); assert!( request @@ -2502,7 +2609,8 @@ mod tests { throughput_control: None, }; let request = - build_transport_request(&operation, None, &ctx).expect("request should build"); + build_transport_request(&operation, &OperationOverrides::default(), None, &ctx) + .expect("request should build"); assert_eq!( request @@ -2548,7 +2656,8 @@ mod tests { throughput_control: None, }; let request = - build_transport_request(&operation, None, &ctx).expect("request should build"); + build_transport_request(&operation, &OperationOverrides::default(), None, &ctx) + .expect("request should build"); assert!( request @@ -2585,7 +2694,9 @@ mod tests { resolved_session_token: None, throughput_control: Some(&snapshot), }; - let request = build_transport_request(&operation, None, &ctx).unwrap(); + let request = + build_transport_request(&operation, &OperationOverrides::default(), None, &ctx) + .unwrap(); let priority = request .headers @@ -2628,7 +2739,9 @@ mod tests { resolved_session_token: None, throughput_control: Some(&snapshot), }; - let request = build_transport_request(&operation, None, &ctx).unwrap(); + let request = + build_transport_request(&operation, &OperationOverrides::default(), None, &ctx) + .unwrap(); let bucket = request .headers @@ -2672,7 +2785,9 @@ mod tests { resolved_session_token: None, throughput_control: Some(&snapshot), }; - let request = build_transport_request(&operation, None, &ctx).unwrap(); + let request = + build_transport_request(&operation, &OperationOverrides::default(), None, &ctx) + .unwrap(); assert_eq!( request.headers.get_optional_str(&HeaderName::from_static( diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/pipeline/retry_evaluation.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/pipeline/retry_evaluation.rs index a0dcef34052..b8cdfea1060 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/driver/pipeline/retry_evaluation.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/pipeline/retry_evaluation.rs @@ -378,7 +378,8 @@ fn try_handle_retry_trigger_group( && status.sub_status() == Some(SubStatusCode::SYSTEM_RESOURCE_UNAVAILABLE); let is_service_unavailable = status.status_code() == azure_core::http::StatusCode::ServiceUnavailable; - let is_gone = status.is_gone(); + // Partition Topology changes (410 with sub-status 1009) are handled by the dataflow layer, not classified as retry triggers here. Only non-topology 410s trigger retries. + let is_gone = status.is_gone() && !status.is_partition_topology_change(); let is_request_timeout = status.status_code() == azure_core::http::StatusCode::RequestTimeout; let in_trigger_group = @@ -695,6 +696,18 @@ mod tests { } } + fn make_http_error_status(status: CosmosStatus) -> TransportResult { + TransportResult { + outcome: TransportOutcome::HttpError { + status, + headers: azure_core::http::headers::Headers::new(), + cosmos_headers: CosmosResponseHeaders::default(), + body: vec![], + request_sent: RequestSentStatus::Sent, + }, + } + } + #[test] fn success_completes() { let op = make_read_operation(); @@ -846,6 +859,55 @@ mod tests { assert!(matches!(action, OperationAction::Abort { .. })); } + #[test] + fn partition_topology_gone_aborts_for_dataflow_handling() { + let op = make_read_operation(); + let result = make_http_error_status( + CosmosStatus::new(StatusCode::Gone) + .with_sub_status(SubStatusCode::PARTITION_KEY_RANGE_GONE.value()), + ); + let state = OperationRetryState::initial(0, false, Vec::new(), 3, 1); + let endpoint = CosmosEndpoint::global( + url::Url::parse("https://test.documents.azure.com:443/").unwrap(), + ); + + let (action, effects) = evaluate_transport_result(&op, &endpoint, result, &state); + + match action { + OperationAction::Abort { status, .. } => { + assert_eq!( + status, + Some( + CosmosStatus::new(StatusCode::Gone) + .with_sub_status(SubStatusCode::PARTITION_KEY_RANGE_GONE.value()) + ) + ); + } + other => panic!("expected abort, got {other:?}"), + } + assert!(effects.is_empty()); + } + + #[test] + fn non_topology_gone_still_retries() { + let op = make_read_operation(); + let result = make_http_error_status( + CosmosStatus::new(StatusCode::Gone) + .with_sub_status(SubStatusCode::NAME_CACHE_STALE.value()), + ); + let state = OperationRetryState::initial(0, false, Vec::new(), 3, 1); + let endpoint = CosmosEndpoint::global( + url::Url::parse("https://test.documents.azure.com:443/").unwrap(), + ); + + let (action, effects) = evaluate_transport_result(&op, &endpoint, result, &state); + + assert!(matches!(action, OperationAction::FailoverRetry { .. })); + assert!(effects + .iter() + .any(|e| matches!(e, LocationEffect::MarkEndpointUnavailable { .. }))); + } + #[test] fn write_forbidden_triggers_failover_and_refresh_effect() { let op = make_create_operation(); diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/transport_pipeline.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/transport_pipeline.rs index ff0cdde4470..7be9c8e809a 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/transport_pipeline.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/transport_pipeline.rs @@ -284,7 +284,24 @@ pub(crate) async fn execute_transport_pipeline( diagnostics.set_fault_injection_evaluations(request_handle, evals); } } - tracing::debug!("transport request complete"); + tracing::debug!( + outcome = ?result.result.outcome, + "transport request complete" + ); + if result.result.is_successful() { + tracing::trace!( + ?result.result.outcome, + "transport attempt complete" + ); + } else if let TransportOutcome::HttpError { status, body, .. } = &result.result.outcome { + let body_str = String::from_utf8_lossy(body); + tracing::warn!(%status, "transport request resulted in HTTP error: {}", body_str); + } else { + tracing::warn!( + ?result.result.outcome, + "transport attempt failed" + ); + } if result.shard_id.is_some_and(|failed_shard_id| { local_connectivity_retry_count < MAX_LOCAL_CONNECTIVITY_RETRIES @@ -301,9 +318,8 @@ pub(crate) async fn execute_transport_pipeline( continue; } - let result = result.result; - // Check for 429 throttling → transport-level retry + let result = result.result; let action = evaluate_transport_retry(&result, &throttle_state); match action { ThrottleAction::Retry { delay, new_state } => { diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/lib.rs b/sdk/cosmos/azure_data_cosmos_driver/src/lib.rs index 817e677d986..8dab31c6cc8 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/lib.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/lib.rs @@ -34,6 +34,6 @@ pub mod testing; // Re-export key types at crate root pub use diagnostics::{DiagnosticsContext, ExecutionContext, RequestDiagnostics, RequestHandle}; -pub use driver::{CosmosDriver, CosmosDriverRuntime, CosmosDriverRuntimeBuilder}; +pub use driver::{CosmosDriver, CosmosDriverRuntime, CosmosDriverRuntimeBuilder, OperationPlan}; pub use models::{ActivityId, CosmosResponse, CosmosStatus, RequestCharge}; pub use options::{DiagnosticsOptions, DiagnosticsVerbosity, DriverOptions}; diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/models/continuation_token.rs b/sdk/cosmos/azure_data_cosmos_driver/src/models/continuation_token.rs new file mode 100644 index 00000000000..dcfbebef815 --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos_driver/src/models/continuation_token.rs @@ -0,0 +1,324 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//! Continuation token type for resumable Cosmos DB feed operations. +//! +//! A [`ContinuationToken`] is an opaque, durable representation of where a +//! feed operation left off. Tokens are produced by the SDK from a live +//! [`OperationPlan`](crate::OperationPlan) and consumed by +//! [`CosmosDriver::plan_operation`](crate::driver::CosmosDriver::plan_operation) +//! to build an equivalent pipeline that resumes at the same position. +//! +//! # Token format +//! +//! SDK-issued tokens start with a version prefix `c.` followed by a +//! base64url-no-pad encoded JSON document. The current version is `c1.`. +//! Tokens with a `c.` prefix where `N > 1` are returned by newer SDKs and +//! are rejected with a clear error. +//! +//! Tokens without a `c.` prefix are treated as opaque server-issued +//! continuation strings and are only valid for trivial operations +//! (single-partition or non-query operations) where the SDK can pass them +//! through unmodified. + +use base64::Engine; +use serde::{Deserialize, Serialize}; + +use crate::driver::dataflow::PipelineNodeState; + +/// Current SDK token version prefix. +const SDK_V1_PREFIX: &str = "c1."; + +/// Opaque continuation token for resuming a paginated Cosmos DB operation. +/// +/// Construct one from a string returned by an earlier query (either the +/// SDK's `to_continuation_token()` output, or — for trivial operations — a +/// raw server-side continuation string). +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ContinuationToken(String); + +impl ContinuationToken { + /// Wraps an opaque continuation string. + /// + /// No validation is performed here; the string is validated when it is + /// passed to + /// [`CosmosDriver::plan_operation`](crate::driver::CosmosDriver::plan_operation). + pub fn from_string(token: String) -> Self { + Self(token) + } + + /// Returns the underlying string form of this token. + pub fn as_str(&self) -> &str { + &self.0 + } + + /// Encodes a [`PipelineNodeState`] as a `c1.`-prefixed token. + pub(crate) fn encode_v1(state: &PipelineNodeState) -> azure_core::Result { + let json = serde_json::to_vec(state).map_err(|e| { + azure_core::Error::with_message( + azure_core::error::ErrorKind::DataConversion, + format!("failed to serialize continuation token state: {e}"), + ) + })?; + let body = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(json); + let mut out = String::with_capacity(SDK_V1_PREFIX.len() + body.len()); + out.push_str(SDK_V1_PREFIX); + out.push_str(&body); + Ok(Self(out)) + } + + /// Resolves this token into a planner-ready form. + pub(crate) fn resolve(&self) -> azure_core::Result { + if let Some(rest) = self.0.strip_prefix(SDK_V1_PREFIX) { + let json = base64::engine::general_purpose::URL_SAFE_NO_PAD + .decode(rest) + .map_err(|e| { + azure_core::Error::with_message( + azure_core::error::ErrorKind::DataConversion, + format!("continuation token has invalid base64 payload: {e}"), + ) + })?; + let state: PipelineNodeState = serde_json::from_slice(&json).map_err(|e| { + azure_core::Error::with_message( + azure_core::error::ErrorKind::DataConversion, + format!("continuation token has invalid JSON payload: {e}"), + ) + })?; + return Ok(ResolvedToken::ClientV1(state)); + } + + if let Some(version) = parse_client_version_prefix(&self.0) { + return Err(azure_core::Error::with_message( + azure_core::error::ErrorKind::DataConversion, + format!( + "continuation token uses unsupported version 'c{version}.'; \ + this SDK only understands 'c1.' tokens — upgrade to a newer SDK" + ), + )); + } + + // No client-version prefix: treat as an opaque server-issued token. + Ok(ResolvedToken::ServerOpaque(self.0.clone())) + } +} + +/// Resolved form of a [`ContinuationToken`] for use during planning. +pub(crate) enum ResolvedToken { + /// A client-issued v1 token containing a snapshot of pipeline state. + ClientV1(PipelineNodeState), + + /// An opaque server continuation string. Only valid for trivial operations. + ServerOpaque(String), +} + +// `PipelineNodeState` lives in driver internals and is not Debug-printable +// outside; provide a tiny Debug shim so test panic messages can include it. +impl std::fmt::Debug for ResolvedToken { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ResolvedToken::ClientV1(state) => write!(f, "ClientV1({state:?})"), + ResolvedToken::ServerOpaque(s) => write!(f, "ServerOpaque({s})"), + } + } +} + +/// Returns `Some(N)` if `s` starts with `c.` for some unsigned integer `N`, +/// otherwise `None`. +/// +/// The `c.` prefix is a deliberate, reserved namespace for SDK-issued +/// tokens (where `N` is the SDK's continuation-token format version). +/// Server-issued opaque continuation tokens have never been observed to start +/// with this pattern, so the SDK treats any `c.` token as SDK-versioned and +/// anything else as a server opaque token. If the server format ever changes +/// to collide with `c.`, this is the place to revisit. +fn parse_client_version_prefix(s: &str) -> Option { + let after_c = s.strip_prefix('c')?; + let dot = after_c.find('.')?; + after_c[..dot].parse::().ok() +} + +// Allow direct serde of ContinuationToken as a string (e.g. for users storing +// it in a JSON document alongside other fields). +impl Serialize for ContinuationToken { + fn serialize(&self, serializer: S) -> Result { + serializer.serialize_str(&self.0) + } +} + +impl<'de> Deserialize<'de> for ContinuationToken { + fn deserialize>(deserializer: D) -> Result { + let s = String::deserialize(deserializer)?; + Ok(Self(s)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Decodes the base64url-no-pad payload of a `c1.`-prefixed token into + /// its raw JSON bytes for inspection. + fn decode_v1_payload(token: &ContinuationToken) -> String { + let body = token + .as_str() + .strip_prefix(SDK_V1_PREFIX) + .expect("token must be c1.-prefixed"); + let bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD + .decode(body) + .expect("payload must be valid base64url-no-pad"); + String::from_utf8(bytes).expect("payload must be valid UTF-8") + } + + /// Builds a `c1.` token whose payload is the given JSON string. + fn encode_v1_payload(json: &str) -> ContinuationToken { + let body = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(json); + ContinuationToken::from_string(format!("{SDK_V1_PREFIX}{body}")) + } + + // ── Serialization ─────────────────────────────────────────────────── + + #[test] + fn encode_v1_drained_state() { + let token = ContinuationToken::encode_v1(&PipelineNodeState::Drained).unwrap(); + assert_eq!(decode_v1_payload(&token), r#"{"kind":"drained"}"#); + } + + #[test] + fn encode_v1_request_state_omits_absent_server_continuation() { + let token = ContinuationToken::encode_v1(&PipelineNodeState::Request { + server_continuation: None, + }) + .unwrap(); + assert_eq!(decode_v1_payload(&token), r#"{"kind":"request"}"#); + } + + #[test] + fn encode_v1_request_state_includes_server_continuation() { + let token = ContinuationToken::encode_v1(&PipelineNodeState::Request { + server_continuation: Some("server-token-1".to_string()), + }) + .unwrap(); + assert_eq!( + decode_v1_payload(&token), + r#"{"kind":"request","server_continuation":"server-token-1"}"#, + ); + } + + #[test] + fn encode_v1_sequential_drain_state() { + let token = ContinuationToken::encode_v1(&PipelineNodeState::SequentialDrain { + current_min_epk: "3F".to_string(), + left_most: Box::new(PipelineNodeState::Request { + server_continuation: None, + }), + }) + .unwrap(); + assert_eq!( + decode_v1_payload(&token), + r#"{"kind":"sequential_drain","current_min_epk":"3F","left_most":{"kind":"request"}}"#, + ); + } + + // ── Deserialization ───────────────────────────────────────────────── + + #[test] + fn resolve_v1_drained_state() { + let token = encode_v1_payload(r#"{"kind":"drained"}"#); + match token.resolve().unwrap() { + ResolvedToken::ClientV1(state) => assert_eq!(state, PipelineNodeState::Drained), + other => panic!("expected ClientV1, got {other:?}"), + } + } + + #[test] + fn resolve_v1_request_state_with_server_continuation() { + let token = + encode_v1_payload(r#"{"kind":"request","server_continuation":"opaque-srv-token"}"#); + match token.resolve().unwrap() { + ResolvedToken::ClientV1(state) => assert_eq!( + state, + PipelineNodeState::Request { + server_continuation: Some("opaque-srv-token".to_string()), + } + ), + other => panic!("expected ClientV1, got {other:?}"), + } + } + + #[test] + fn resolve_v1_request_state_without_server_continuation() { + let token = encode_v1_payload(r#"{"kind":"request"}"#); + match token.resolve().unwrap() { + ResolvedToken::ClientV1(state) => assert_eq!( + state, + PipelineNodeState::Request { + server_continuation: None, + } + ), + other => panic!("expected ClientV1, got {other:?}"), + } + } + + #[test] + fn resolve_v1_sequential_drain_state() { + let token = encode_v1_payload( + r#"{"kind":"sequential_drain","current_min_epk":"3F","left_most":{"kind":"request"}}"#, + ); + match token.resolve().unwrap() { + ResolvedToken::ClientV1(state) => assert_eq!( + state, + PipelineNodeState::SequentialDrain { + current_min_epk: "3F".to_string(), + left_most: Box::new(PipelineNodeState::Request { + server_continuation: None, + }), + } + ), + other => panic!("expected ClientV1, got {other:?}"), + } + } + + // ── Error and fallback paths ──────────────────────────────────────── + + #[test] + fn rejects_newer_sdk_token() { + // cspell:ignore somethingnew + let token = ContinuationToken::from_string("c2.somethingnew".to_string()); + let err = token.resolve().unwrap_err(); + assert!(matches!( + err.kind(), + azure_core::error::ErrorKind::DataConversion + )); + assert!(err.to_string().contains("c2.")); + } + + #[test] + fn server_opaque_token_when_no_prefix() { + let token = ContinuationToken::from_string("opaque-server-string".to_string()); + match token.resolve().unwrap() { + ResolvedToken::ServerOpaque(s) => assert_eq!(s, "opaque-server-string"), + other => panic!("expected ServerOpaque, got {other:?}"), + } + } + + #[test] + fn rejects_invalid_base64_in_v1_token() { + // cspell:ignore notvalid + let token = ContinuationToken::from_string("c1.!!!notvalid!!!".to_string()); + let err = token.resolve().unwrap_err(); + assert!(matches!( + err.kind(), + azure_core::error::ErrorKind::DataConversion + )); + } + + #[test] + fn rejects_invalid_json_in_v1_token() { + let token = encode_v1_payload(r#"{"kind":"unknown_variant"}"#); + let err = token.resolve().unwrap_err(); + assert!(matches!( + err.kind(), + azure_core::error::ErrorKind::DataConversion + )); + } +} diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/models/cosmos_headers.rs b/sdk/cosmos/azure_data_cosmos_driver/src/models/cosmos_headers.rs index 7feda17a020..7a03b6d29f9 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/models/cosmos_headers.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/models/cosmos_headers.rs @@ -20,14 +20,24 @@ pub(crate) mod request_header_names { pub const IF_MATCH: &str = "if-match"; pub const IF_NONE_MATCH: &str = "if-none-match"; pub const PREFER: &str = "prefer"; + pub const IS_QUERY: &str = "x-ms-documentdb-isquery"; + pub const IS_QUERY_PLAN_REQUEST: &str = "x-ms-cosmos-is-query-plan-request"; + pub const SUPPORTED_QUERY_FEATURES: &str = "x-ms-cosmos-supported-query-features"; pub const IS_UPSERT: &str = "x-ms-documentdb-is-upsert"; pub const IS_BATCH_REQUEST: &str = "x-ms-cosmos-is-batch-request"; pub const BATCH_ATOMIC: &str = "x-ms-cosmos-batch-atomic"; pub const BATCH_CONTINUE_ON_ERROR: &str = "x-ms-cosmos-batch-continue-on-error"; + pub const CONTINUATION: &str = "x-ms-continuation"; pub const OFFER_THROUGHPUT: &str = "x-ms-offer-throughput"; pub const OFFER_AUTOPILOT_SETTINGS: &str = "x-ms-cosmos-offer-autopilot-settings"; pub const PRIORITY_LEVEL: &str = "x-ms-cosmos-priority-level"; pub const THROUGHPUT_BUCKET: &str = "x-ms-cosmos-throughput-bucket"; + pub const START_EPK: &str = "x-ms-start-epk"; + pub const END_EPK: &str = "x-ms-end-epk"; + #[allow(dead_code)] // Reserved for future direct partition-key header writes. + pub const PARTITION_KEY: &str = "x-ms-documentdb-partitionkey"; + pub const PARTITION_KEY_RANGE_ID: &str = "x-ms-documentdb-partitionkeyrangeid"; + pub const MAX_ITEM_COUNT: &str = "x-ms-max-item-count"; } /// Standard Cosmos DB response header names. @@ -72,6 +82,8 @@ pub(crate) mod response_header_names { "x-ms-documentdb-collection-lazy-indexing-progress"; } +pub const QUERY_CONTENT_TYPE: &str = "application/query+json"; + /// Header names used by the fault injection framework. #[cfg(feature = "fault_injection")] pub(crate) mod fault_injection_header_names { @@ -101,6 +113,18 @@ pub struct CosmosRequestHeaders { /// /// The driver serializes this to JSON for the header value. pub offer_autopilot_settings: Option, + + /// Supported query features (`x-ms-cosmos-supported-query-features`). + /// + /// Sent on query plan requests to indicate which query capabilities the + /// client supports. The backend uses this to shape its response. + pub supported_query_features: Option, + + /// Maximum number of items the server should return per page + /// (`x-ms-max-item-count`). + /// + /// Applies to feed-style operations such as queries and read-feed. + pub max_item_count: Option, } impl CosmosRequestHeaders { @@ -149,6 +173,18 @@ impl CosmosRequestHeaders { ); } } + if let Some(features) = self.supported_query_features.as_ref() { + headers.insert( + request_header_names::SUPPORTED_QUERY_FEATURES, + HeaderValue::from(features.clone()), + ); + } + if let Some(max_item_count) = self.max_item_count { + headers.insert( + request_header_names::MAX_ITEM_COUNT, + HeaderValue::from(max_item_count.to_string()), + ); + } } } @@ -736,6 +772,8 @@ mod tests { precondition: None, offer_throughput: None, offer_autopilot_settings: None, + supported_query_features: None, + max_item_count: None, }; assert_eq!( @@ -756,6 +794,8 @@ mod tests { precondition: None, offer_throughput: None, offer_autopilot_settings: None, + supported_query_features: None, + max_item_count: None, }; let mut headers = Headers::new(); @@ -779,6 +819,8 @@ mod tests { precondition: Some(Precondition::if_match(ETag::new("etag-value-1"))), offer_throughput: None, offer_autopilot_settings: None, + supported_query_features: None, + max_item_count: None, }; let mut headers = Headers::new(); @@ -802,6 +844,8 @@ mod tests { precondition: Some(Precondition::if_none_match(ETag::new("*"))), offer_throughput: None, offer_autopilot_settings: None, + supported_query_features: None, + max_item_count: None, }; let mut headers = Headers::new(); @@ -825,6 +869,8 @@ mod tests { precondition: None, offer_throughput: None, offer_autopilot_settings: None, + supported_query_features: None, + max_item_count: None, }; let mut headers = Headers::new(); @@ -848,6 +894,8 @@ mod tests { precondition: Some(Precondition::if_match(ETag::new("etag-abc"))), offer_throughput: None, offer_autopilot_settings: None, + supported_query_features: None, + max_item_count: None, }; let mut headers = Headers::new(); @@ -870,4 +918,28 @@ mod tests { None ); } + #[test] + fn write_to_headers_emits_max_item_count() { + let cosmos_headers = CosmosRequestHeaders { + max_item_count: Some(7), + ..Default::default() + }; + let mut headers = Headers::new(); + cosmos_headers.write_to_headers(&mut headers); + assert_eq!( + headers.get_optional_str(&HeaderName::from_static("x-ms-max-item-count")), + Some("7") + ); + } + + #[test] + fn write_to_headers_omits_max_item_count_when_none() { + let cosmos_headers = CosmosRequestHeaders::default(); + let mut headers = Headers::new(); + cosmos_headers.write_to_headers(&mut headers); + assert_eq!( + headers.get_optional_str(&HeaderName::from_static("x-ms-max-item-count")), + None + ); + } } diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/models/cosmos_operation.rs b/sdk/cosmos/azure_data_cosmos_driver/src/models/cosmos_operation.rs index 37473a1d89e..3007a4243d2 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/models/cosmos_operation.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/models/cosmos_operation.rs @@ -5,7 +5,8 @@ use crate::models::{ AccountReference, ContainerReference, CosmosRequestHeaders, CosmosResourceReference, - DatabaseReference, ItemReference, OperationType, PartitionKey, Precondition, ResourceType, + DatabaseReference, ItemReference, OperationTarget, OperationType, PartitionKey, Precondition, + ResourceType, }; use std::borrow::Cow; @@ -48,7 +49,7 @@ use std::borrow::Cow; /// // 3. Build and execute item operations /// let item = ItemReference::from_name(&container, PartitionKey::from("pk1"), "doc1"); /// let result = driver -/// .execute_operation(CosmosOperation::read_item(item), OperationOptions::default()) +/// .execute_point_operation(CosmosOperation::read_item(item), OperationOptions::default()) /// .await?; /// # Ok(()) /// # } @@ -62,8 +63,8 @@ pub struct CosmosOperation { resource_type: ResourceType, /// Reference to the resource being operated on. resource_reference: CosmosResourceReference, - /// Optional partition key for data plane operations. - partition_key: Option, + /// Describes how the operation targets the partition key space. + target: OperationTarget, /// Additional request headers to include in the request. request_headers: CosmosRequestHeaders, /// Optional request body (raw bytes, schema-agnostic). @@ -111,9 +112,9 @@ impl CosmosOperation { self.resource_reference.container() } - /// Returns the partition key, if set. - pub fn partition_key(&self) -> Option<&PartitionKey> { - self.partition_key.as_ref() + /// Returns the operation target. + pub fn target(&self) -> &OperationTarget { + &self.target } /// Returns the request headers. @@ -126,12 +127,6 @@ impl CosmosOperation { self.body.as_deref() } - /// Sets the partition key for the operation. - pub fn with_partition_key(mut self, partition_key: impl Into) -> Self { - self.partition_key = Some(partition_key.into()); - self - } - /// Sets request headers for the operation. pub fn with_request_headers(mut self, headers: CosmosRequestHeaders) -> Self { self.request_headers = headers; @@ -153,6 +148,15 @@ impl CosmosOperation { self } + /// Sets the maximum number of items the server should return per page + /// (the `x-ms-max-item-count` request header). + /// + /// Applies to feed-style operations such as queries and read-feed. + pub fn with_max_item_count(mut self, max_item_count: u32) -> Self { + self.request_headers.max_item_count = Some(max_item_count); + self + } + /// Sets the precondition for optimistic concurrency control. pub fn with_precondition(mut self, precondition: Precondition) -> Self { self.request_headers.precondition = Some(precondition); @@ -172,18 +176,23 @@ impl CosmosOperation { // ===== Factory Methods ===== - /// Creates a new operation with the specified type and resource reference. + /// Creates a new operation with the specified type, resource reference, and target. fn new( operation_type: OperationType, resource_reference: impl Into, + target: OperationTarget, ) -> Self { let resource_reference = resource_reference.into(); let resource_type = resource_reference.resource_type(); + debug_assert!( + !resource_type.is_partitioned(operation_type) || target.has_partition_reference(), + "Attempted to create a partitioned operation without an OperationTarget specifying the partitions to access" + ); Self { operation_type, resource_type, resource_reference, - partition_key: None, + target, request_headers: CosmosRequestHeaders::new(), body: None, } @@ -216,7 +225,7 @@ impl CosmosOperation { let resource_ref: CosmosResourceReference = CosmosResourceReference::from(account) .with_resource_type(ResourceType::Database) .into_feed_reference(); - Self::new(OperationType::Create, resource_ref) + Self::new(OperationType::Create, resource_ref, OperationTarget::None) } /// Reads (lists) all databases in the account. @@ -226,7 +235,7 @@ impl CosmosOperation { let resource_ref = Into::::into(account) .with_resource_type(ResourceType::Database) .into_feed_reference(); - Self::new(OperationType::ReadFeed, resource_ref) + Self::new(OperationType::ReadFeed, resource_ref, OperationTarget::None) } /// Queries databases in the account. @@ -236,7 +245,7 @@ impl CosmosOperation { let resource_ref: CosmosResourceReference = CosmosResourceReference::from(account) .with_resource_type(ResourceType::Database) .into_feed_reference(); - Self::new(OperationType::Query, resource_ref) + Self::new(OperationType::Query, resource_ref, OperationTarget::None) } /// Deletes a database. @@ -259,7 +268,7 @@ impl CosmosOperation { /// ``` pub fn delete_database(database: DatabaseReference) -> Self { let resource_ref: CosmosResourceReference = database.into(); - Self::new(OperationType::Delete, resource_ref) + Self::new(OperationType::Delete, resource_ref, OperationTarget::None) } /// Reads a database's properties from the service. @@ -268,7 +277,7 @@ impl CosmosOperation { /// the system-managed `_rid`, `_ts`, and `_etag`. pub fn read_database(database: DatabaseReference) -> Self { let resource_ref: CosmosResourceReference = database.into(); - Self::new(OperationType::Read, resource_ref) + Self::new(OperationType::Read, resource_ref, OperationTarget::None) } /// Creates a container in a database. @@ -299,7 +308,7 @@ impl CosmosOperation { let resource_ref: CosmosResourceReference = CosmosResourceReference::from(database) .with_resource_type(ResourceType::DocumentCollection) .into_feed_reference(); - Self::new(OperationType::Create, resource_ref) + Self::new(OperationType::Create, resource_ref, OperationTarget::None) } /// Reads (lists) all containers in a database. @@ -309,7 +318,7 @@ impl CosmosOperation { let resource_ref: CosmosResourceReference = CosmosResourceReference::from(database) .with_resource_type(ResourceType::DocumentCollection) .into_feed_reference(); - Self::new(OperationType::ReadFeed, resource_ref) + Self::new(OperationType::ReadFeed, resource_ref, OperationTarget::None) } /// Queries containers in a database. @@ -319,7 +328,7 @@ impl CosmosOperation { let resource_ref: CosmosResourceReference = CosmosResourceReference::from(database) .with_resource_type(ResourceType::DocumentCollection) .into_feed_reference(); - Self::new(OperationType::Query, resource_ref) + Self::new(OperationType::Query, resource_ref, OperationTarget::None) } /// Deletes a container. @@ -344,7 +353,7 @@ impl CosmosOperation { /// let container = driver.resolve_container("my-database", "my-container").await?; /// /// let result = driver - /// .execute_operation( + /// .execute_point_operation( /// CosmosOperation::delete_container(container), /// OperationOptions::default(), /// ) @@ -354,7 +363,7 @@ impl CosmosOperation { /// ``` pub fn delete_container(container: ContainerReference) -> Self { let resource_ref: CosmosResourceReference = container.into(); - Self::new(OperationType::Delete, resource_ref) + Self::new(OperationType::Delete, resource_ref, OperationTarget::None) } /// Replaces a container's properties. @@ -362,7 +371,7 @@ impl CosmosOperation { /// Use `with_body()` to provide the updated container properties JSON. pub fn replace_container(container: ContainerReference) -> Self { let resource_ref: CosmosResourceReference = container.into(); - Self::new(OperationType::Replace, resource_ref) + Self::new(OperationType::Replace, resource_ref, OperationTarget::None) } /// Reads a container's properties from the service. @@ -371,7 +380,7 @@ impl CosmosOperation { /// including system-managed properties like `_rid`, `_ts`, and `_etag`. pub fn read_container(container: ContainerReference) -> Self { let resource_ref: CosmosResourceReference = container.into(); - Self::new(OperationType::Read, resource_ref) + Self::new(OperationType::Read, resource_ref, OperationTarget::None) } /// Reads a container's properties by database and container name. @@ -386,7 +395,7 @@ impl CosmosOperation { let resource_ref: CosmosResourceReference = CosmosResourceReference::from(database) .with_resource_type(ResourceType::DocumentCollection) .with_name(container_name.into()); - Self::new(OperationType::Read, resource_ref) + Self::new(OperationType::Read, resource_ref, OperationTarget::None) } /// Reads a container's properties by database RID and container RID. @@ -397,7 +406,7 @@ impl CosmosOperation { let resource_ref: CosmosResourceReference = CosmosResourceReference::from(database) .with_resource_type(ResourceType::DocumentCollection) .with_rid(container_rid.into()); - Self::new(OperationType::Read, resource_ref) + Self::new(OperationType::Read, resource_ref, OperationTarget::None) } // ===== Data Plane Factory Methods ===== @@ -429,7 +438,7 @@ impl CosmosOperation { /// /// let item = ItemReference::from_name(&container, PartitionKey::from("pk-value"), "doc1"); /// let result = driver - /// .execute_operation( + /// .execute_point_operation( /// CosmosOperation::create_item(item) /// .with_body(br#"{"id": "doc1", "pk": "pk-value", "data": "hello"}"#.to_vec()), /// OperationOptions::default(), @@ -440,7 +449,11 @@ impl CosmosOperation { /// ``` pub fn create_item(item: ItemReference) -> Self { let partition_key = item.partition_key().clone(); - Self::new(OperationType::Create, item).with_partition_key(partition_key) + Self::new( + OperationType::Create, + item, + OperationTarget::PartitionKey(partition_key), + ) } /// Reads an item (document) from a container. @@ -470,14 +483,18 @@ impl CosmosOperation { /// /// let item = ItemReference::from_name(&container, PartitionKey::from("pk-value"), "doc1"); /// let result = driver - /// .execute_operation(CosmosOperation::read_item(item), OperationOptions::default()) + /// .execute_point_operation(CosmosOperation::read_item(item), OperationOptions::default()) /// .await?; /// # Ok(()) /// # } /// ``` pub fn read_item(item: ItemReference) -> Self { let partition_key = item.partition_key().clone(); - Self::new(OperationType::Read, item).with_partition_key(partition_key) + Self::new( + OperationType::Read, + item, + OperationTarget::PartitionKey(partition_key), + ) } /// Deletes an item (document) from a container. @@ -486,7 +503,11 @@ impl CosmosOperation { /// providing all the information needed for the operation. pub fn delete_item(item: ItemReference) -> Self { let partition_key = item.partition_key().clone(); - Self::new(OperationType::Delete, item).with_partition_key(partition_key) + Self::new( + OperationType::Delete, + item, + OperationTarget::PartitionKey(partition_key), + ) } /// Executes a transactional batch of operations against a single partition. @@ -498,7 +519,11 @@ impl CosmosOperation { let resource_ref: CosmosResourceReference = CosmosResourceReference::from(container) .with_resource_type(ResourceType::Document) .into_feed_reference(); - Self::new(OperationType::Batch, resource_ref).with_partition_key(partition_key) + Self::new( + OperationType::Batch, + resource_ref, + OperationTarget::PartitionKey(partition_key), + ) } /// Upserts (creates or replaces) an item (document) in a container. @@ -509,7 +534,11 @@ impl CosmosOperation { /// If an item with the same ID exists, it will be replaced; otherwise, a new item is created. pub fn upsert_item(item: ItemReference) -> Self { let partition_key = item.partition_key().clone(); - Self::new(OperationType::Upsert, item).with_partition_key(partition_key) + Self::new( + OperationType::Upsert, + item, + OperationTarget::PartitionKey(partition_key), + ) } /// Replaces an existing item (document) in a container. @@ -519,7 +548,11 @@ impl CosmosOperation { /// Use `with_body()` to provide the new document JSON. pub fn replace_item(item: ItemReference) -> Self { let partition_key = item.partition_key().clone(); - Self::new(OperationType::Replace, item).with_partition_key(partition_key) + Self::new( + OperationType::Replace, + item, + OperationTarget::PartitionKey(partition_key), + ) } /// Reads (lists) all items within a single partition. @@ -530,7 +563,11 @@ impl CosmosOperation { let resource_ref: CosmosResourceReference = CosmosResourceReference::from(container) .with_resource_type(ResourceType::Document) .into_feed_reference(); - Self::new(OperationType::ReadFeed, resource_ref).with_partition_key(partition_key) + Self::new( + OperationType::ReadFeed, + resource_ref, + OperationTarget::PartitionKey(partition_key), + ) } /// Reads (lists) all items across all partitions. @@ -544,33 +581,57 @@ impl CosmosOperation { let resource_ref: CosmosResourceReference = CosmosResourceReference::from(container) .with_resource_type(ResourceType::Document) .into_feed_reference(); - Self::new(OperationType::ReadFeed, resource_ref) + Self::new( + OperationType::ReadFeed, + resource_ref, + OperationTarget::FeedRange(crate::models::FeedRange::full()), + ) } - /// Queries items within a single partition. + /// Queries items in a container. + /// + /// The `target` determines partition scope: use + /// [`OperationTarget::PartitionKey`] for single-partition queries, or + /// [`OperationTarget::FeedRange`] for cross-partition queries. /// /// Use `with_body()` to provide the query JSON. - /// This is more efficient than cross-partition queries. - pub fn query_items(container: ContainerReference, partition_key: PartitionKey) -> Self { + pub fn query_items(container: ContainerReference, target: OperationTarget) -> Self { let resource_ref: CosmosResourceReference = CosmosResourceReference::from(container) .with_resource_type(ResourceType::Document) .into_feed_reference(); - Self::new(OperationType::Query, resource_ref).with_partition_key(partition_key) + Self::new(OperationType::Query, resource_ref, target) } - /// Queries items across all partitions. + /// Creates a query plan request for a container. /// - /// Use `with_body()` to provide the query JSON. + /// The query plan request is sent to the backend gateway to obtain + /// execution metadata (partition targeting, rewritten query, etc.) + /// before issuing the actual cross-partition query. /// - /// This is equivalent to calling `query_items()` with [`PartitionKey::EMPTY`], - /// which causes the `x-ms-documentdb-query-enablecrosspartition` header to be - /// emitted by the pipeline. + /// Use `with_body()` to provide the query JSON (same as the original query). + pub(crate) fn query_plan(container: ContainerReference) -> Self { + let resource_ref: CosmosResourceReference = CosmosResourceReference::from(container) + .with_resource_type(ResourceType::Document) + .into_feed_reference(); + let mut headers = CosmosRequestHeaders::new(); + headers.supported_query_features = Some(String::new()); + Self::new( + OperationType::QueryPlan, + resource_ref, + OperationTarget::None, + ) + .with_request_headers(headers) + } + + /// Creates a read-feed request for partition key ranges in a container. /// - /// **Warning:** Cross-partition queries are inherently less efficient than - /// single-partition queries. Use `query_items()` with a partition key - /// when possible. - pub fn query_items_cross_partition(container: ContainerReference) -> Self { - Self::query_items(container, PartitionKey::EMPTY) + /// Used to populate the partition key range cache for topology resolution. + #[allow(dead_code)] // Reserved for an upcoming pk-range cache refresh path. + pub(crate) fn read_partition_key_ranges(container: ContainerReference) -> Self { + let resource_ref: CosmosResourceReference = CosmosResourceReference::from(container) + .with_resource_type(ResourceType::PartitionKeyRange) + .into_feed_reference(); + Self::new(OperationType::ReadFeed, resource_ref, OperationTarget::None) } /// Reads (lists) all partition key ranges for a container. @@ -588,7 +649,7 @@ impl CosmosOperation { let resource_ref: CosmosResourceReference = CosmosResourceReference::from(container) .with_resource_type(ResourceType::PartitionKeyRange) .into_feed_reference(); - Self::new(OperationType::ReadFeed, resource_ref) + Self::new(OperationType::ReadFeed, resource_ref, OperationTarget::None) } /// Returns true if this is a read-only operation. @@ -601,6 +662,21 @@ impl CosmosOperation { self.operation_type.is_idempotent() } + /// Returns true if this operation can be planned with a single-node pipeline. + /// + /// An operation is "trivial" when it does not require fan-out across multiple + /// physical partitions. This includes all non-query operations and queries + /// that target a specific logical partition key (single-partition queries) + /// OR queries against a non-partitioned resource (Databases, Containers, Offers, etc.). + /// + /// Cross-partition queries (those targeting a [`FeedRange`](crate::models::FeedRange)) + /// are **not** trivial and require a backend query plan to determine the + /// fan-out strategy. + pub fn is_trivial(&self) -> bool { + self.operation_type != OperationType::Query + || !matches!(self.target(), OperationTarget::FeedRange(_)) + } + // -- Offer operations -- /// Queries offers in the account. @@ -611,7 +687,7 @@ impl CosmosOperation { let resource_ref: CosmosResourceReference = CosmosResourceReference::from(account) .with_resource_type(ResourceType::Offer) .into_feed_reference(); - Self::new(OperationType::Query, resource_ref) + Self::new(OperationType::Query, resource_ref, OperationTarget::None) } /// Reads a specific offer by its ID. @@ -621,7 +697,7 @@ impl CosmosOperation { let resource_ref: CosmosResourceReference = CosmosResourceReference::from(account) .with_resource_type(ResourceType::Offer) .with_rid(offer_id.into()); - Self::new(OperationType::Read, resource_ref) + Self::new(OperationType::Read, resource_ref, OperationTarget::None) } /// Replaces a specific offer by its ID. @@ -635,7 +711,7 @@ impl CosmosOperation { let resource_ref: CosmosResourceReference = CosmosResourceReference::from(account) .with_resource_type(ResourceType::Offer) .with_rid(offer_id.into()); - Self::new(OperationType::Replace, resource_ref) + Self::new(OperationType::Replace, resource_ref, OperationTarget::None) } } @@ -681,10 +757,14 @@ mod tests { #[test] fn create_operation() { - let item_ref = - ItemReference::from_name(&test_container(), PartitionKey::from("pk1"), "doc1"); + let pk = PartitionKey::from("pk1"); + let item_ref = ItemReference::from_name(&test_container(), pk.clone(), "doc1"); let resource_ref: CosmosResourceReference = item_ref.into(); - let op = CosmosOperation::new(OperationType::Create, resource_ref); + let op = CosmosOperation::new( + OperationType::Create, + resource_ref, + OperationTarget::PartitionKey(pk), + ); assert_eq!(op.operation_type(), OperationType::Create); assert_eq!(op.resource_type(), ResourceType::Document); @@ -694,10 +774,14 @@ mod tests { #[test] fn read_operation() { - let item_ref = - ItemReference::from_name(&test_container(), PartitionKey::from("pk1"), "doc1"); + let pk = PartitionKey::from("pk1"); + let item_ref = ItemReference::from_name(&test_container(), pk.clone(), "doc1"); let resource_ref: CosmosResourceReference = item_ref.into(); - let op = CosmosOperation::new(OperationType::Read, resource_ref); + let op = CosmosOperation::new( + OperationType::Read, + resource_ref, + OperationTarget::PartitionKey(pk), + ); assert_eq!(op.operation_type(), OperationType::Read); assert_eq!(op.resource_type(), ResourceType::Document); @@ -710,29 +794,41 @@ mod tests { let item_ref = ItemReference::from_name(&test_container(), PartitionKey::from("pk1"), "doc1"); let resource_ref: CosmosResourceReference = item_ref.into(); - let op = CosmosOperation::new(OperationType::Read, resource_ref) - .with_partition_key(PartitionKey::from("pk1")); + let op = CosmosOperation::new( + OperationType::Read, + resource_ref, + OperationTarget::PartitionKey(PartitionKey::from("pk1")), + ); - assert!(op.partition_key().is_some()); + assert!(matches!(op.target(), OperationTarget::PartitionKey(_))); } #[test] fn operation_with_body() { - let item_ref = - ItemReference::from_name(&test_container(), PartitionKey::from("pk1"), "doc1"); + let pk = PartitionKey::from("pk1"); + let item_ref = ItemReference::from_name(&test_container(), pk.clone(), "doc1"); let resource_ref: CosmosResourceReference = item_ref.into(); let body = b"{\"id\":\"doc1\"}".to_vec(); - let op = CosmosOperation::new(OperationType::Create, resource_ref).with_body(body.clone()); + let op = CosmosOperation::new( + OperationType::Create, + resource_ref, + OperationTarget::PartitionKey(pk), + ) + .with_body(body.clone()); assert_eq!(op.body(), Some(body.as_slice())); } #[test] fn replace_is_idempotent() { - let item_ref = - ItemReference::from_name(&test_container(), PartitionKey::from("pk1"), "doc1"); + let pk = PartitionKey::from("pk1"); + let item_ref = ItemReference::from_name(&test_container(), pk.clone(), "doc1"); let resource_ref: CosmosResourceReference = item_ref.into(); - let op = CosmosOperation::new(OperationType::Replace, resource_ref); + let op = CosmosOperation::new( + OperationType::Replace, + resource_ref, + OperationTarget::PartitionKey(pk), + ); assert!(!op.is_read_only()); assert!(op.is_idempotent()); @@ -740,12 +836,27 @@ mod tests { #[test] fn upsert_is_not_idempotent() { - let item_ref = - ItemReference::from_name(&test_container(), PartitionKey::from("pk1"), "doc1"); + let pk = PartitionKey::from("pk1"); + let item_ref = ItemReference::from_name(&test_container(), pk.clone(), "doc1"); let resource_ref: CosmosResourceReference = item_ref.into(); - let op = CosmosOperation::new(OperationType::Upsert, resource_ref); + let op = CosmosOperation::new( + OperationType::Upsert, + resource_ref, + OperationTarget::PartitionKey(pk), + ); assert!(!op.is_read_only()); assert!(!op.is_idempotent()); } + + /// Creating a partitioned operation without a partition target panics in + /// debug builds and silently proceeds in release builds. + #[test] + #[cfg_attr(debug_assertions, should_panic)] + fn rejects_partitioned_operation_without_target() { + let item_ref = + ItemReference::from_name(&test_container(), PartitionKey::from("pk1"), "doc1"); + let resource_ref: CosmosResourceReference = item_ref.into(); + let _op = CosmosOperation::new(OperationType::Create, resource_ref, OperationTarget::None); + } } diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/models/cosmos_status.rs b/sdk/cosmos/azure_data_cosmos_driver/src/models/cosmos_status.rs index ec9f60b36fe..7d6365f40c6 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/models/cosmos_status.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/models/cosmos_status.rs @@ -1288,6 +1288,19 @@ impl CosmosStatus { && self.sub_status == Some(SubStatusCode::PARTITION_KEY_RANGE_GONE) } + /// Returns `true` if this is an HTTP 410 caused by partition topology changing. + pub(crate) fn is_partition_topology_change(&self) -> bool { + u16::from(self.status_code) == 410 + && matches!( + self.sub_status, + Some( + SubStatusCode::PARTITION_KEY_RANGE_GONE + | SubStatusCode::COMPLETING_SPLIT + | SubStatusCode::COMPLETING_PARTITION_MIGRATION + ) + ) + } + /// Returns `true` if this indicates a transport-generated 503 (client-side). pub fn is_transport_generated_503(&self) -> bool { u16::from(self.status_code) == 503 diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/models/feed_range.rs b/sdk/cosmos/azure_data_cosmos_driver/src/models/feed_range.rs new file mode 100644 index 00000000000..3b731cad969 --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos_driver/src/models/feed_range.rs @@ -0,0 +1,388 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//! Feed range type for the Cosmos DB driver. +//! +//! A [`FeedRange`] represents a contiguous range of the effective partition key (EPK) space. +//! It is used by the dataflow pipeline to target operations at one or more physical partitions. +//! +//! Feed ranges can also be serialized to base64-encoded JSON for cross-SDK storage and transport. + +use azure_core::{error::ErrorKind, fmt::SafeDebug}; +use base64::Engine; +use serde::{Deserialize, Serialize}; +use std::{cmp, fmt, str::FromStr}; + +use crate::models::effective_partition_key::EffectivePartitionKey; +use crate::models::partition_key_range::PartitionKeyRange; + +/// A contiguous range of the effective partition key space. +/// +/// Defined by `[min_inclusive, max_exclusive)` EPK boundaries. A `FeedRange` may +/// map to one or several physical partitions depending on the current partition +/// topology. +/// +/// Use [`FeedRange::full()`] for the entire key space (`""..FF`). +#[derive(Clone, SafeDebug, PartialEq, Eq, Hash)] +#[safe(true)] +pub struct FeedRange { + min_inclusive: EffectivePartitionKey, + max_exclusive: EffectivePartitionKey, +} + +#[derive(Serialize, Deserialize)] +struct FeedRangeJson { + #[serde(rename = "Range")] + range: RangeJson, +} + +#[derive(Serialize, Deserialize)] +struct RangeJson { + min: String, + max: String, + #[serde(rename = "isMinInclusive")] + is_min_inclusive: bool, + #[serde(rename = "isMaxInclusive")] + is_max_inclusive: bool, +} + +impl FeedRange { + /// Creates a feed range from explicit EPK bounds. + pub fn new(min_inclusive: EffectivePartitionKey, max_exclusive: EffectivePartitionKey) -> Self { + Self { + min_inclusive, + max_exclusive, + } + } + + /// Creates a feed range covering the entire partition key space (`""..FF`). + pub fn full() -> Self { + Self { + min_inclusive: EffectivePartitionKey::min(), + max_exclusive: EffectivePartitionKey::max(), + } + } + + /// Returns the inclusive lower bound of this range. + pub fn min_inclusive(&self) -> &EffectivePartitionKey { + &self.min_inclusive + } + + /// Returns the exclusive upper bound of this range. + pub fn max_exclusive(&self) -> &EffectivePartitionKey { + &self.max_exclusive + } + + /// Returns `true` if this feed range is entirely contained within `other`. + pub fn is_subset_of(&self, other: &FeedRange) -> bool { + other.min_inclusive <= self.min_inclusive && other.max_exclusive >= self.max_exclusive + } + + /// Returns `true` if this feed range and `other` share any portion of the EPK space. + /// + /// Two feed ranges overlap when one starts before the other ends and vice versa. + pub fn overlaps(&self, other: &FeedRange) -> bool { + self.min_inclusive < other.max_exclusive && other.min_inclusive < self.max_exclusive + } + + /// Returns `true` if this feed range can be combined with `other`. + /// + /// Two ranges can be combined when they overlap or are adjacent + /// (one's max equals the other's min). + pub fn can_merge(&self, other: &FeedRange) -> bool { + self.max_exclusive >= other.min_inclusive && other.max_exclusive >= self.min_inclusive + } + + /// Combines this feed range with `other` into a bounding range. + pub fn merge_with(&self, other: &FeedRange) -> FeedRange { + debug_assert!( + self.can_merge(other), + "merge_with called on disjoint ranges" + ); + FeedRange { + min_inclusive: cmp::min(self.min_inclusive.clone(), other.min_inclusive.clone()), + max_exclusive: cmp::max(self.max_exclusive.clone(), other.max_exclusive.clone()), + } + } + + fn to_json(&self) -> FeedRangeJson { + FeedRangeJson { + range: RangeJson { + min: self.min_inclusive.as_str().to_owned(), + max: self.max_exclusive.as_str().to_owned(), + is_min_inclusive: true, + is_max_inclusive: false, + }, + } + } + + fn from_json(json: FeedRangeJson) -> azure_core::Result { + if !json.range.is_min_inclusive || json.range.is_max_inclusive { + return Err(azure_core::Error::with_message( + ErrorKind::DataConversion, + "feed range must have [min, max) semantics (isMinInclusive=true, isMaxInclusive=false)", + )); + } + + let min = EffectivePartitionKey::from(json.range.min); + let max = EffectivePartitionKey::from(json.range.max); + + if min > max { + return Err(azure_core::Error::with_message( + ErrorKind::DataConversion, + "feed range min must be less than or equal to max", + )); + } + + Ok(Self { + min_inclusive: min, + max_exclusive: max, + }) + } +} + +impl TryFrom<&PartitionKeyRange> for FeedRange { + type Error = azure_core::Error; + + /// Creates a `FeedRange` from a driver `PartitionKeyRange`. + /// + /// Partition key ranges from the service always use `[min, max)` semantics + /// (min inclusive, max exclusive). Returns an error if the range is inverted. + fn try_from(pkr: &PartitionKeyRange) -> Result { + if pkr.min_inclusive > pkr.max_exclusive { + return Err(azure_core::Error::with_message( + ErrorKind::DataConversion, + "partition key range min_inclusive must be <= max_exclusive", + )); + } + + Ok(Self { + min_inclusive: EffectivePartitionKey::from(pkr.min_inclusive.as_str()), + max_exclusive: EffectivePartitionKey::from(pkr.max_exclusive.as_str()), + }) + } +} + +impl fmt::Display for FeedRange { + /// Formats this feed range as a base64-encoded JSON string. + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let json_str = serde_json::to_string(&self.to_json()).map_err(|_| fmt::Error)?; + let encoded = base64::engine::general_purpose::STANDARD.encode(json_str.as_bytes()); + f.write_str(&encoded) + } +} + +impl FromStr for FeedRange { + type Err = azure_core::Error; + + /// Parses a feed range from a base64-encoded JSON string. + fn from_str(s: &str) -> Result { + let decoded_bytes = base64::engine::general_purpose::STANDARD + .decode(s) + .map_err(|e| azure_core::Error::new(ErrorKind::DataConversion, e))?; + + let json: FeedRangeJson = serde_json::from_slice(&decoded_bytes) + .map_err(|e| azure_core::Error::new(ErrorKind::DataConversion, e))?; + + Self::from_json(json) + } +} + +impl Serialize for FeedRange { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + self.to_json().serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for FeedRange { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let json = FeedRangeJson::deserialize(deserializer)?; + Self::from_json(json).map_err(|e| serde::de::Error::custom(e.to_string())) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn full_range() { + let full = FeedRange::full(); + assert_eq!(full.min_inclusive().as_str(), ""); + assert_eq!(full.max_exclusive().as_str(), "FF"); + } + + #[test] + fn is_subset_of_full() { + let full = FeedRange::full(); + let sub = FeedRange::new( + EffectivePartitionKey::from("00"), + EffectivePartitionKey::from("80"), + ); + assert!(sub.is_subset_of(&full)); + assert!(!full.is_subset_of(&sub)); + } + + #[test] + fn is_subset_of_self() { + let range = FeedRange::new( + EffectivePartitionKey::from("20"), + EffectivePartitionKey::from("80"), + ); + assert!(range.is_subset_of(&range)); + } + + #[test] + fn overlaps_basic() { + let a = FeedRange::new( + EffectivePartitionKey::from("00"), + EffectivePartitionKey::from("50"), + ); + let b = FeedRange::new( + EffectivePartitionKey::from("30"), + EffectivePartitionKey::from("80"), + ); + assert!(a.overlaps(&b)); + assert!(b.overlaps(&a)); + } + + #[test] + fn overlaps_adjacent_no_overlap() { + let a = FeedRange::new( + EffectivePartitionKey::from("00"), + EffectivePartitionKey::from("50"), + ); + let b = FeedRange::new( + EffectivePartitionKey::from("50"), + EffectivePartitionKey::from("FF"), + ); + assert!(!a.overlaps(&b)); + assert!(!b.overlaps(&a)); + } + + #[test] + fn overlaps_disjoint() { + let a = FeedRange::new( + EffectivePartitionKey::from("00"), + EffectivePartitionKey::from("30"), + ); + let b = FeedRange::new( + EffectivePartitionKey::from("50"), + EffectivePartitionKey::from("FF"), + ); + assert!(!a.overlaps(&b)); + assert!(!b.overlaps(&a)); + } + + #[test] + fn can_merge_adjacent() { + let a = FeedRange::new( + EffectivePartitionKey::from("00"), + EffectivePartitionKey::from("50"), + ); + let b = FeedRange::new( + EffectivePartitionKey::from("50"), + EffectivePartitionKey::from("FF"), + ); + + assert!(a.can_merge(&b)); + assert!(b.can_merge(&a)); + } + + #[test] + fn merge_with_bounds() { + let a = FeedRange::new( + EffectivePartitionKey::from("00"), + EffectivePartitionKey::from("50"), + ); + let b = FeedRange::new( + EffectivePartitionKey::from("30"), + EffectivePartitionKey::from("FF"), + ); + + let merged = a.merge_with(&b); + assert_eq!(merged.min_inclusive().as_str(), "00"); + assert_eq!(merged.max_exclusive().as_str(), "FF"); + } + + #[test] + fn display_round_trip() { + let range = FeedRange::new( + EffectivePartitionKey::from("3FFFFFFFFFFF"), + EffectivePartitionKey::from("7FFFFFFFFFFF"), + ); + + let serialized = range.to_string(); + let parsed: FeedRange = serialized.parse().unwrap(); + + assert_eq!(parsed, range); + } + + #[test] + fn serde_json_round_trip() { + let range = FeedRange::new( + EffectivePartitionKey::from(""), + EffectivePartitionKey::from("FF"), + ); + + let json = serde_json::to_string(&range).unwrap(); + let parsed: FeedRange = serde_json::from_str(&json).unwrap(); + + assert_eq!(parsed, range); + } + + #[test] + fn try_from_partition_key_range() { + let pkr = PartitionKeyRange::new("0".to_string(), "".to_string(), "FF".to_string()); + let feed_range = FeedRange::try_from(&pkr).unwrap(); + + assert_eq!(feed_range.min_inclusive().as_str(), ""); + assert_eq!(feed_range.max_exclusive().as_str(), "FF"); + } + + #[test] + fn from_str_invalid_base64() { + assert!("not-valid-base64!!!".parse::().is_err()); + } + + #[test] + fn from_str_invalid_json() { + let encoded = base64::engine::general_purpose::STANDARD.encode(b"not json"); + assert!(encoded.parse::().is_err()); + } + + #[test] + fn from_str_rejects_max_inclusive() { + let json = r#"{"Range":{"min":"","max":"FF","isMinInclusive":true,"isMaxInclusive":true}}"#; + let encoded = base64::engine::general_purpose::STANDARD.encode(json.as_bytes()); + assert!(encoded.parse::().is_err()); + } + + #[test] + fn serde_rejects_min_not_inclusive() { + let json = + r#"{"Range":{"min":"","max":"FF","isMinInclusive":false,"isMaxInclusive":false}}"#; + assert!(serde_json::from_str::(json).is_err()); + } + + #[test] + fn from_str_rejects_inverted_range() { + let json = + r#"{"Range":{"min":"FF","max":"","isMinInclusive":true,"isMaxInclusive":false}}"#; + let encoded = base64::engine::general_purpose::STANDARD.encode(json.as_bytes()); + assert!(encoded.parse::().is_err()); + } + + #[test] + fn serde_rejects_inverted_range() { + let json = + r#"{"Range":{"min":"FF","max":"","isMinInclusive":true,"isMaxInclusive":false}}"#; + assert!(serde_json::from_str::(json).is_err()); + } +} diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/models/mod.rs b/sdk/cosmos/azure_data_cosmos_driver/src/models/mod.rs index 3dd87bcf6c5..13de94f6161 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/models/mod.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/models/mod.rs @@ -13,6 +13,7 @@ mod account_reference; mod activity_id; mod connection_string; mod consistency_level; +mod continuation_token; pub(crate) mod cosmos_headers; mod cosmos_operation; mod cosmos_resource_reference; @@ -31,8 +32,10 @@ pub(crate) mod vector_session_token; pub(crate) use cosmos_headers::request_header_names; #[allow(dead_code)] pub mod effective_partition_key; +mod feed_range; #[allow(dead_code)] mod murmur_hash; +mod operation_target; #[allow(dead_code)] pub mod partition_key_range; #[allow(dead_code)] @@ -42,6 +45,8 @@ pub use account_reference::{AccountReference, AccountReferenceBuilder, Credentia pub use activity_id::ActivityId; pub use connection_string::ConnectionString; pub(crate) use consistency_level::DefaultConsistencyLevel; +pub use continuation_token::ContinuationToken; +pub(crate) use continuation_token::ResolvedToken; pub use cosmos_headers::{ AutoscaleAutoUpgradePolicy, AutoscaleThroughputPolicy, CosmosRequestHeaders, CosmosResponseHeaders, OfferAutoscaleSettings, @@ -52,7 +57,10 @@ pub(crate) use cosmos_resource_reference::ResourcePaths; pub use cosmos_response::CosmosResponse; pub use cosmos_status::CosmosStatus; pub use cosmos_status::SubStatusCode; +pub use effective_partition_key::EffectivePartitionKey; pub use etag::{ETag, Precondition}; +pub use feed_range::FeedRange; +pub use operation_target::OperationTarget; pub use partition_key::{PartitionKey, PartitionKeyValue}; pub use request_charge::RequestCharge; pub use resource_reference::ContainerReference; @@ -370,6 +378,40 @@ impl ResourceType { ) } + /// Returns true if the given [`OperationTarget`] is valid for this resource type. + /// + /// Each resource type only supports a subset of targeting modes: + /// - Non-partitioned resources (`DatabaseAccount`, `Database`, `DocumentCollection`, + /// `PartitionKeyRange`, `Offer`) require [`OperationTarget::None`]. + /// - Documents require either a [`OperationTarget::PartitionKey`] or + /// [`OperationTarget::FeedRange`]. + /// - Server-side code resources (`StoredProcedure`, `Trigger`, `UserDefinedFunction`) + /// accept [`OperationTarget::None`] for CRUD and [`OperationTarget::PartitionKey`] + /// for execution. + pub fn is_valid_target(self, target: &OperationTarget) -> bool { + match self { + ResourceType::DatabaseAccount + | ResourceType::Database + | ResourceType::DocumentCollection + | ResourceType::PartitionKeyRange + | ResourceType::Offer => matches!(target, OperationTarget::None), + + ResourceType::Document => matches!( + target, + OperationTarget::PartitionKey(_) | OperationTarget::FeedRange(_) + ), + + ResourceType::StoredProcedure + | ResourceType::Trigger + | ResourceType::UserDefinedFunction => { + matches!( + target, + OperationTarget::None | OperationTarget::PartitionKey(_) + ) + } + } + } + /// Returns true if this resource type supports partition-level failover. /// /// Documents are partitioned for all operations except [`OperationType::QueryPlan`], @@ -833,4 +875,77 @@ mod tests { // Higher version (2) wins for globalLSN; region 1: max(100, 50) = 100 assert_eq!(merged.as_str(), "0:2#200#1=100"); } + + // --- ResourceType::is_valid_target --- + + #[test] + fn none_target_valid_for_database() { + assert!(ResourceType::Database.is_valid_target(&OperationTarget::None)); + } + + #[test] + fn none_target_valid_for_database_account() { + assert!(ResourceType::DatabaseAccount.is_valid_target(&OperationTarget::None)); + } + + #[test] + fn none_target_valid_for_document_collection() { + assert!(ResourceType::DocumentCollection.is_valid_target(&OperationTarget::None)); + } + + #[test] + fn none_target_valid_for_offer() { + assert!(ResourceType::Offer.is_valid_target(&OperationTarget::None)); + } + + #[test] + fn none_target_valid_for_partition_key_range() { + assert!(ResourceType::PartitionKeyRange.is_valid_target(&OperationTarget::None)); + } + + #[test] + fn none_target_invalid_for_document() { + assert!(!ResourceType::Document.is_valid_target(&OperationTarget::None)); + } + + #[test] + fn partition_key_valid_for_document() { + let pk = OperationTarget::PartitionKey(PartitionKey::from("pk")); + assert!(ResourceType::Document.is_valid_target(&pk)); + } + + #[test] + fn feed_range_valid_for_document() { + let fr = OperationTarget::FeedRange(FeedRange::full()); + assert!(ResourceType::Document.is_valid_target(&fr)); + } + + #[test] + fn partition_key_invalid_for_database() { + let pk = OperationTarget::PartitionKey(PartitionKey::from("pk")); + assert!(!ResourceType::Database.is_valid_target(&pk)); + } + + #[test] + fn feed_range_invalid_for_database() { + let fr = OperationTarget::FeedRange(FeedRange::full()); + assert!(!ResourceType::Database.is_valid_target(&fr)); + } + + #[test] + fn none_target_valid_for_stored_procedure() { + assert!(ResourceType::StoredProcedure.is_valid_target(&OperationTarget::None)); + } + + #[test] + fn partition_key_valid_for_stored_procedure() { + let pk = OperationTarget::PartitionKey(PartitionKey::from("pk")); + assert!(ResourceType::StoredProcedure.is_valid_target(&pk)); + } + + #[test] + fn feed_range_invalid_for_stored_procedure() { + let fr = OperationTarget::FeedRange(FeedRange::full()); + assert!(!ResourceType::StoredProcedure.is_valid_target(&fr)); + } } diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/models/operation_target.rs b/sdk/cosmos/azure_data_cosmos_driver/src/models/operation_target.rs new file mode 100644 index 00000000000..63fbbe7dea4 --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos_driver/src/models/operation_target.rs @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//! Operation targeting for Cosmos DB operations. + +use crate::models::{FeedRange, PartitionKey}; + +/// Describes how an operation targets the partition key space. +/// +/// Every [`CosmosOperation`](crate::models::CosmosOperation) carries an `OperationTarget` +/// that determines how the driver routes the request: +/// +/// - [`None`](Self::None) — account/database/container-level operations that have no +/// partition scope (e.g., create database, read container). +/// - [`PartitionKey`](Self::PartitionKey) — operations scoped to a single logical +/// partition. Always executed as a single request (point operation). +/// - [`FeedRange`](Self::FeedRange) — operations scoped to an EPK range that may +/// span one or more physical partitions (e.g., cross-partition queries). +#[derive(Clone, Debug)] +pub enum OperationTarget { + /// No partition scope. Used for account, database, and container-level operations. + /// + /// It is illegal to use this target for item-level operations inside a container. + None, + + /// Scoped to a single logical partition key. + /// + /// This can always be satisfied by a single request node — no fan-out required. + PartitionKey(PartitionKey), + + /// Scoped to a feed range (EPK range). + /// + /// The range may cover one or more physical partitions, including the full + /// container key space ([`FeedRange::full()`]). + FeedRange(FeedRange), +} + +impl OperationTarget { + /// Returns `true` if the target has a partition reference (i.e., it is not [`None`](Self::None)). + pub fn has_partition_reference(&self) -> bool { + !matches!(self, OperationTarget::None) + } +} diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/models/partition_key.rs b/sdk/cosmos/azure_data_cosmos_driver/src/models/partition_key.rs index f2c76e58d7b..800fa78e718 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/models/partition_key.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/models/partition_key.rs @@ -4,7 +4,10 @@ //! Partition key types for Cosmos DB operations. -use crate::models::FiniteF64; +use crate::models::{ + effective_partition_key::EffectivePartitionKey, FiniteF64, PartitionKeyKind, + PartitionKeyVersion, +}; use azure_core::http::headers::{AsHeaders, HeaderName, HeaderValue}; use std::{borrow::Cow, hash::Hash}; @@ -151,6 +154,12 @@ impl From for PartitionKeyValue { } impl PartitionKeyValue { + /// The Null partition key value. + pub const NULL: Self = Self(InnerPartitionKeyValue::Null); + + /// The Undefined partition key value. + pub const UNDEFINED: Self = Self(InnerPartitionKeyValue::Undefined); + /// Writes this value into a byte buffer using the V2 hashing encoding. /// /// Used by the effective partition key computation for MurmurHash3-128. @@ -300,6 +309,12 @@ impl Default for PartitionKey { } impl PartitionKey { + /// A single null partition key value. + pub const NULL: PartitionKeyValue = PartitionKeyValue::NULL; + + /// A single undefined partition key value. + pub const UNDEFINED: PartitionKeyValue = PartitionKeyValue::UNDEFINED; + /// An empty partition key, used to signal a cross-partition operation. pub const EMPTY: PartitionKey = PartitionKey(Vec::new()); @@ -322,6 +337,27 @@ impl PartitionKey { pub fn values(&self) -> &[PartitionKeyValue] { &self.0 } + + /// Returns a hex string representation of the partition key hash. + pub fn get_hashed_partition_key_string( + &self, + kind: PartitionKeyKind, + version: u8, + ) -> EffectivePartitionKey { + let version = match version { + 1 => PartitionKeyVersion::V1, + 2 => PartitionKeyVersion::V2, + unsupported => { + tracing::warn!( + "Partition key hashing version {} is unsupported in SDK API; defaulting to V2", + unsupported + ); + PartitionKeyVersion::V2 + } + }; + + EffectivePartitionKey::compute(&self.0, kind, version) + } } impl AsHeaders for PartitionKey { diff --git a/sdk/cosmos/azure_data_cosmos_driver/tests/emulator_tests/driver_backup_endpoints.rs b/sdk/cosmos/azure_data_cosmos_driver/tests/emulator_tests/driver_backup_endpoints.rs index 6706a24fa11..f086dff6775 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/tests/emulator_tests/driver_backup_endpoints.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/tests/emulator_tests/driver_backup_endpoints.rs @@ -81,7 +81,7 @@ async fn driver_operations_work_after_backup_boot() -> Result<(), Box let operation = CosmosOperation::create_database(account.clone()).with_body(body.into_bytes()); let result = driver - .execute_operation(operation, OperationOptions::default()) + .execute_point_operation(operation, OperationOptions::default()) .await; assert!( @@ -93,7 +93,7 @@ async fn driver_operations_work_after_backup_boot() -> Result<(), Box // Cleanup let db_ref = DatabaseReference::from_name(account, db_name); let _ = driver - .execute_operation( + .execute_point_operation( CosmosOperation::delete_database(db_ref), OperationOptions::default(), ) diff --git a/sdk/cosmos/azure_data_cosmos_driver/tests/framework/test_client.rs b/sdk/cosmos/azure_data_cosmos_driver/tests/framework/test_client.rs index a4d12ee6dec..7dac7768b46 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/tests/framework/test_client.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/tests/framework/test_client.rs @@ -15,6 +15,7 @@ use azure_data_cosmos_driver::{ options::{ConnectionPoolOptions, EmulatorServerCertValidation, OperationOptions}, }; use std::{error::Error, future::Future, sync::Arc}; +use tracing_subscriber::EnvFilter; use uuid::Uuid; use super::env::{ @@ -39,7 +40,13 @@ pub struct TestEnv { /// Returns `Ok(None)` if the environment is not configured and tests should be skipped. pub fn resolve_test_env() -> Result, Box> { let _ = tracing_subscriber::fmt::fmt() - .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_env_filter( + EnvFilter::builder() + // Tests with intentional failures cause noise, so we set the default level to "off" + // to silence them unless the user explicitly configures it. + .with_default_directive("off".parse().unwrap()) + .from_env_lossy(), + ) .try_init(); let test_mode = get_test_mode(); @@ -296,7 +303,7 @@ impl DriverTestRunContext { .with_body(body.into_bytes()); let result = driver - .execute_operation(operation, OperationOptions::default()) + .execute_point_operation(operation, OperationOptions::default()) .await?; // Check for success status (201 Created) @@ -326,7 +333,7 @@ impl DriverTestRunContext { let operation = CosmosOperation::delete_database(database.clone()); let result = driver - .execute_operation(operation, OperationOptions::default()) + .execute_point_operation(operation, OperationOptions::default()) .await?; // Check for success status (204 No Content) @@ -360,7 +367,7 @@ impl DriverTestRunContext { CosmosOperation::create_container(database.clone()).with_body(body.into_bytes()); let result = driver - .execute_operation(operation, OperationOptions::default()) + .execute_point_operation(operation, OperationOptions::default()) .await?; // Check for success status (201 Created) @@ -403,7 +410,7 @@ impl DriverTestRunContext { let operation = CosmosOperation::create_item(item_ref).with_body(body.to_vec()); let result = driver - .execute_operation(operation, OperationOptions::default()) + .execute_point_operation(operation, OperationOptions::default()) .await?; Ok(result) @@ -443,7 +450,7 @@ impl DriverTestRunContext { let operation = CosmosOperation::read_item(item_ref); let result = driver - .execute_operation(operation, OperationOptions::default()) + .execute_point_operation(operation, OperationOptions::default()) .await?; Ok(result) diff --git a/sdk/cosmos/azure_data_cosmos_driver/tests/multi_region_failover.rs b/sdk/cosmos/azure_data_cosmos_driver/tests/multi_region_failover.rs index fe6d2cad804..22ee8067bd6 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/tests/multi_region_failover.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/tests/multi_region_failover.rs @@ -54,7 +54,7 @@ async fn write_forbidden_triggers_refresh_and_failover() { ); let _ = driver - .execute_operation( + .execute_point_operation( CosmosOperation::read_database(db_ref), OperationOptions::default(), ) @@ -88,7 +88,7 @@ async fn session_not_available_retries_across_locations() { ); let _ = driver - .execute_operation( + .execute_point_operation( CosmosOperation::read_database(db_ref), OperationOptions::default(), ) diff --git a/sdk/cosmos/azure_data_cosmos_perf/src/operations/query_items.rs b/sdk/cosmos/azure_data_cosmos_perf/src/operations/query_items.rs index 00883f6ea5f..3088e3ef8cd 100644 --- a/sdk/cosmos/azure_data_cosmos_perf/src/operations/query_items.rs +++ b/sdk/cosmos/azure_data_cosmos_perf/src/operations/query_items.rs @@ -7,8 +7,8 @@ use std::sync::Arc; use std::time::Duration; use async_trait::async_trait; -use azure_data_cosmos::clients::ContainerClient; use azure_data_cosmos::Query; +use azure_data_cosmos::{clients::ContainerClient, query::QueryScope}; use futures::StreamExt; use super::{extract_backend_duration, Operation}; @@ -40,7 +40,8 @@ impl Operation for QueryItemsOperation { Query::from("SELECT * FROM c WHERE c.partition_key = @pk").with_parameter("@pk", pk)?; let mut stream = container - .query_items::(query, pk, None)? + .query_items::(query, QueryScope::partition(pk), None) + .await? .into_pages(); // Sum backend durations across pages so a multi-page query reports