diff --git a/sdk/cosmos/azure_data_cosmos/CHANGELOG.md b/sdk/cosmos/azure_data_cosmos/CHANGELOG.md index 1d0c37f05c6..f7c9133a346 100644 --- a/sdk/cosmos/azure_data_cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure_data_cosmos/CHANGELOG.md @@ -4,8 +4,12 @@ ### Features Added +- Added `CosmosClientBuilder::with_gateway20_disabled(bool)` to opt out of the new Gateway 2.0 transport, which is now enabled by default. Gateway 2.0 routes data-plane requests through a regional proxy that forwards RNTBD-over-HTTP/2 to the backend. Set this to `true` to fall back to the direct gateway transport — useful for workloads that depend on the published gateway latency SLAs (Gateway 2.0 is not currently covered by them) or that need the direct-gateway behavior for diagnostics. ([#4319](https://github.com/Azure/azure-sdk-for-rust/pull/4319)) + ### Breaking Changes +- Consolidated SDK fault-injection types as re-exports from `azure_data_cosmos_driver::fault_injection`. `FaultInjectionRule`, `FaultInjectionCondition`, `FaultInjectionResult`, `CustomResponse`, `FaultInjectionErrorType`, `FaultOperationType`, and the matching builders are now provided by the driver crate. Field access is via accessor methods (e.g., `rule.id()`, `condition.region()`, `response.body()`) rather than direct field reads. The SDK retains only `FaultInjectionClientBuilder` (gateway-side transport wrapper). ([#4319](https://github.com/Azure/azure-sdk-for-rust/pull/4319)) + ### Bugs Fixed ### Other Changes diff --git a/sdk/cosmos/azure_data_cosmos/build.rs b/sdk/cosmos/azure_data_cosmos/build.rs index 9862a46d7ee..b22c7b395db 100644 --- a/sdk/cosmos/azure_data_cosmos/build.rs +++ b/sdk/cosmos/azure_data_cosmos/build.rs @@ -6,5 +6,7 @@ // unknown cfg names are warned/denied unless explicitly declared via check-cfg. fn main() { // Allow `#[cfg_attr(not(test_category = "..."), ignore)]` in `tests/*.rs`. 
- println!("cargo:rustc-check-cfg=cfg(test_category, values(\"emulator\", \"multi_write\", \"split\"))"); + println!( + "cargo:rustc-check-cfg=cfg(test_category, values(\"emulator\", \"multi_write\", \"split\", \"gateway20\"))" + ); } diff --git a/sdk/cosmos/azure_data_cosmos/docs/sdk-to-driver-cutover.md b/sdk/cosmos/azure_data_cosmos/docs/sdk-to-driver-cutover.md index bb3b53e0b8b..8d4e3eaf20a 100644 --- a/sdk/cosmos/azure_data_cosmos/docs/sdk-to-driver-cutover.md +++ b/sdk/cosmos/azure_data_cosmos/docs/sdk-to-driver-cutover.md @@ -315,42 +315,24 @@ The gateway pipeline tracked this via `CosmosRequest` (which held the final URL) ## Fault Injection Wiring -When cutting `read_item` over to the driver, the SDK's fault injection tests initially failed because the two execution paths (gateway and driver) have **independent fault injection systems**. This section documents how they were connected. +The SDK no longer ships a parallel fault-injection type system. All fault-injection types — [`FaultInjectionRule`], [`FaultInjectionCondition`], [`FaultInjectionResult`], [`CustomResponse`], [`FaultInjectionErrorType`], [`FaultOperationType`], and the matching builders — are re-exported directly from the driver crate (`azure_data_cosmos_driver::fault_injection`) by `azure_data_cosmos::fault_injection`. The SDK only owns: -### Problem +- [`FaultInjectionClientBuilder`] — produces the `azure_core::http::Transport` that the SDK pipeline plugs in (i.e., a `FaultClient` HTTP client wrapper that evaluates driver rules against in-flight gateway requests). +- A small private `fault_operation_for_sdk(SdkOperationType, SdkResourceType) → Option<FaultOperationType>` adapter so `CosmosRequest::add_fault_injection_headers` can stamp the right operation tag on the outbound headers. -The SDK and driver each have their own fault injection module (`azure_data_cosmos::fault_injection` and `azure_data_cosmos_driver::fault_injection`).
They define parallel but separate types (`FaultInjectionRule`, `FaultInjectionCondition`, `FaultInjectionResult`, etc.) with identical variants but different Rust types. Prior to this work, only the gateway pipeline received fault injection rules — the driver was built without them. - -### Solution: Rule Translation with Shared State - -The bridge module (`driver_bridge.rs`) includes `sdk_fi_rules_to_driver_fi_rules()`, which translates SDK fault injection rules into driver fault injection rules. The translation covers: - -- `FaultOperationType` — variant-by-variant match (identical variant names) -- `FaultInjectionErrorType` — variant-by-variant match -- `FaultInjectionCondition` — `RegionName` → `Region`, operation type and container ID mapped directly -- `FaultInjectionResult` — `Duration` → `Option<Duration>`, probability copied -- Timing fields — `start_time: Instant` → `Option<Instant>`, `end_time` and `hit_limit` copied - -### Shared Mutable State - -SDK `FaultInjectionRule` has `enabled: Arc<AtomicBool>` and `hit_count: Arc<AtomicU64>` that tests mutate at runtime (`.disable()`, `.enable()`, `.hit_count()`). The driver's `FaultInjectionRuleBuilder` accepts external `Arc`s via `with_shared_state()`, so both the SDK gateway path and the driver path reference the **same atomic state**. This means: - -- Calling `.disable()` on the SDK rule also disables it in the driver -- Hit counts are shared — both paths increment the same counter -- Tests that toggle rules or assert hit counts work correctly across both paths +Because both transports (gateway and driver) consume the **same** `Arc<FaultInjectionRule>` instances now, there is no translation step and no shared-state plumbing — toggling `enable()`/`disable()`, hit-count increments, and `hit_limit` enforcement all happen against one canonical rule object. ### Wiring in `CosmosClientBuilder` In `CosmosClientBuilder::build()`: -1. Before the `FaultInjectionClientBuilder` is consumed for the gateway transport, `rules()` extracts a reference to the SDK rules -2.
`sdk_fi_rules_to_driver_fi_rules()` translates them to driver rules with shared state -3. The translated rules are passed to `CosmosDriverRuntimeBuilder::with_fault_injection_rules()` -4. The SDK's `fault_injection` Cargo feature now forwards to the driver's `fault_injection` feature +1. The `FaultInjectionClientBuilder::rules()` accessor returns `&[Arc<FaultInjectionRule>]` — already the driver type, so the SDK simply clones the slice (`fault_builder.rules().to_vec()`). +2. The cloned rules are passed to `CosmosDriverRuntimeBuilder::with_fault_injection_rules()` so the driver's own fault-injection HTTP client can evaluate them. +3. The `FaultInjectionClientBuilder` is then consumed to build the gateway transport, which wraps the inner `HttpClient` with a `FaultClient` that evaluates the same rules. ### Test Patterns for Future Cutover -When cutting over additional operations, **no additional fault injection wiring is needed** — it's handled once at the `CosmosClientBuilder` level. However, tests need to account for two behavioral differences: +When cutting over additional operations, **no fault-injection wiring changes are needed** — it's all wired once at `CosmosClientBuilder::build()`. However, tests need to account for two behavioral differences between gateway-routed and driver-routed operations: **`request_url()` returns `None` for driver-routed operations:** @@ -378,24 +360,7 @@ let rule = FaultInjectionRuleBuilder::new("test", error) This asymmetry will disappear once all operations are driver-routed, since there will be only one hit-counting path. -### `custom_response` Translation - -Translation of `CustomResponse` (synthetic HTTP responses) is not yet implemented. None of the current tests use custom responses for `ReadItem` operations. When needed, the bridge function should be extended to translate `CustomResponse` fields (`status_code`, `headers`, `body`).
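The shared-rule semantics the doc relies on (both transports cloning the same `Arc`, so a disable or a hit on either path is visible to the other, and `hit_limit` is enforced globally) can be sketched with a simplified stand-in. The `Rule` struct and `try_hit` helper below are hypothetical illustrations, not the driver's actual `FaultInjectionRule` API:

```rust
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;

// Illustrative stand-in for the driver's `FaultInjectionRule`; the real type
// lives in `azure_data_cosmos_driver::fault_injection` and differs in detail.
struct Rule {
    enabled: AtomicBool,
    hit_count: AtomicU64,
    hit_limit: Option<u64>,
}

impl Rule {
    fn disable(&self) {
        self.enabled.store(false, Ordering::SeqCst);
    }
    fn hit_count(&self) -> u64 {
        self.hit_count.load(Ordering::SeqCst)
    }
    // Returns true when the fault fires; mirrors hit-limit enforcement.
    fn try_hit(&self) -> bool {
        if !self.enabled.load(Ordering::SeqCst) {
            return false;
        }
        if let Some(limit) = self.hit_limit {
            if self.hit_count() >= limit {
                return false;
            }
        }
        self.hit_count.fetch_add(1, Ordering::SeqCst);
        true
    }
}

fn main() {
    let rule = Arc::new(Rule {
        enabled: AtomicBool::new(true),
        hit_count: AtomicU64::new(0),
        hit_limit: Some(2),
    });
    // Both "transports" hold clones of the *same* Arc.
    let (gateway, driver) = (Arc::clone(&rule), Arc::clone(&rule));
    assert!(gateway.try_hit()); // gateway path injects: shared count = 1
    assert!(driver.try_hit()); // driver path injects: shared count = 2
    assert!(!gateway.try_hit()); // hit_limit reached, regardless of path
    rule.disable(); // one toggle disables the rule everywhere
    assert_eq!(rule.hit_count(), 2);
}
```

This is why no translation bridge or `with_shared_state()` plumbing is needed once both paths consume the same rule objects: the atomics live in one place.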
- -### Consolidating to Driver Fault Injection After Cutover - -The current dual-system architecture (SDK fault injection + driver fault injection + translation bridge) exists only because the cutover is incremental — some operations still go through the gateway while others go through the driver. Once **all** operations are routed through the driver: - -1. **Drop `azure_data_cosmos::fault_injection`** — the SDK's HTTP-client-level fault interception module becomes unreachable. Delete the entire `src/fault_injection/` directory. -2. **Re-export driver types** — the SDK re-exports the driver's fault injection types directly: - - ```rust - #[cfg(feature = "fault_injection")] - pub use azure_data_cosmos_driver::fault_injection; - ``` +### Final State After Cutover -3. **Remove the translation layer** — `sdk_fi_rules_to_driver_fi_rules()` in `driver_bridge.rs` and the `shared_enabled()`/`shared_hit_count()` accessors on the SDK rule are no longer needed. -4. **Simplify `CosmosClientBuilder`** — `with_fault_injection()` accepts `Vec<Arc<FaultInjectionRule>>` directly and passes them to `CosmosDriverRuntimeBuilder::with_fault_injection_rules()`. No translation, no cloning, no intermediary builder. -5. **Update tests** — tests construct driver `FaultInjectionRule` directly (same builders, same API) instead of SDK rules. +Once **all** operations are routed through the driver, the SDK-side `FaultInjectionClientBuilder` and `FaultClient` HTTP wrapper become unreachable too — the driver-runtime fault-injection HTTP client is the single source of truth. At that point `azure_data_cosmos::fault_injection` collapses into a pure `pub use azure_data_cosmos_driver::fault_injection;` re-export (or is dropped entirely). -At that point the SDK has **no fault injection logic of its own** — it's a pass-through to the driver, matching the overall "SDK as thin wrapper" goal. The driver is the single source of truth for all transport-related concerns including fault injection.
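The `with_gateway20_disabled(bool)` override described in the changelog is a three-state setting on the builder: never called (`None`, driver decides), `Some(true)` (force the standard gateway), or `Some(false)` (explicit opt-in matching the default). The decision logic can be sketched as pure Rust; `select_gateway20` is a hypothetical helper written for illustration, not SDK API:

```rust
// Illustrative sketch of the three-state `gateway20_disabled: Option<bool>`
// override on `CosmosClientBuilder`. `select_gateway20` is a made-up name;
// the real routing decision happens inside the driver's connection pool.
fn select_gateway20(override_disabled: Option<bool>, account_advertises_gw20: bool) -> bool {
    match override_disabled {
        // with_gateway20_disabled(true): force the standard gateway transport.
        Some(true) => false,
        // with_gateway20_disabled(false) or no call at all: driver decides,
        // using Gateway 2.0 whenever the account advertises an endpoint.
        Some(false) | None => account_advertises_gw20,
    }
}

fn main() {
    assert!(select_gateway20(None, true)); // default: Gateway 2.0 when advertised
    assert!(!select_gateway20(None, false)); // no Gateway 2.0 endpoint: standard gateway
    assert!(!select_gateway20(Some(true), true)); // explicit opt-out wins
    assert!(select_gateway20(Some(false), true)); // explicit opt-in matches default
}
```

Workloads that need the published gateway latency SLAs would call `with_gateway20_disabled(true)`, since Gateway 2.0 is not currently covered by them.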
diff --git a/sdk/cosmos/azure_data_cosmos/src/clients/cosmos_client_builder.rs b/sdk/cosmos/azure_data_cosmos/src/clients/cosmos_client_builder.rs index ff7d6953bb3..ad35092b727 100644 --- a/sdk/cosmos/azure_data_cosmos/src/clients/cosmos_client_builder.rs +++ b/sdk/cosmos/azure_data_cosmos/src/clients/cosmos_client_builder.rs @@ -93,6 +93,15 @@ pub struct CosmosClientBuilder { fault_injection_builder: Option<FaultInjectionClientBuilder>, /// Fallback endpoints tried when the primary endpoint is unavailable. backup_endpoints: Vec, + /// Operator override for the Gateway 2.0 transport. + /// + /// `None` (the default) leaves the underlying driver in charge of + /// routing — Gateway 2.0 is selected automatically whenever the + /// account advertises a Gateway 2.0 endpoint and HTTP/2 is allowed. + /// `Some(true)` forces every request through the standard gateway + /// transport via [`with_gateway20_disabled`](Self::with_gateway20_disabled); + /// `Some(false)` explicitly opts in (matching the default behaviour). + gateway20_disabled: Option<bool>, } impl CosmosClientBuilder { @@ -168,6 +177,41 @@ impl CosmosClientBuilder { self } + /// Disables the Gateway 2.0 transport for this client. + /// + /// Gateway 2.0 is the next-generation Cosmos DB dataplane transport: + /// SDK connections terminate at a regional Gateway 2.0 proxy that + /// forwards RNTBD-over-HTTP/2 to the backend. **Gateway 2.0 is enabled + /// by default** — whenever the account advertises a Gateway 2.0 endpoint + /// the SDK routes eligible dataplane operations through it and falls + /// back to the standard gateway only for operations Gateway 2.0 cannot + /// serve (e.g. metadata requests or accounts that do not advertise a + /// Gateway 2.0 endpoint). + /// + /// Pass `true` to opt out and force every request through the standard + /// gateway transport. The standard gateway path remains supported and + /// stable — disabling Gateway 2.0 is the recommended workaround if you + /// hit a regression on the new transport.
+ /// + /// # Latency caveat + /// + /// Gateway 2.0 traffic flows through a proxy that is + /// **not currently covered by the regional Cosmos DB latency SLA**. + /// Workloads with strict P99 latency requirements should opt out via + /// `with_gateway20_disabled(true)` until the proxy reaches general + /// availability. The extra hop also means Gateway 2.0 may add measurable + /// latency relative to the standard gateway in some regions. + /// + /// # Arguments + /// + /// * `disabled` - `true` to suppress Gateway 2.0 and force the standard + /// gateway transport; `false` (or leaving the builder untouched) keeps + /// the default Gateway 2.0 behaviour. + pub fn with_gateway20_disabled(mut self, disabled: bool) -> Self { + self.gateway20_disabled = Some(disabled); + self + } + /// Registers a throughput control group on the driver runtime. /// /// Groups define throughput policies (priority level, throughput bucket) that @@ -287,9 +331,10 @@ Option<Transport>, Vec<Arc<FaultInjectionRule>>, ) = if let Some(fault_builder) = self.fault_injection_builder { - // Translate rules for the driver before the builder is consumed. - let driver_rules = - crate::driver_bridge::sdk_fi_rules_to_driver_fi_rules(fault_builder.rules()); + // SDK fault-injection rules are now driver `FaultInjectionRule`s + // (re-exported through `crate::fault_injection`), so the driver + // can consume them directly without a translation step.
+ let driver_rules = fault_builder.rules().to_vec(); let fault_builder = match base_client { Some(client) => fault_builder.with_inner_client(client), None => fault_builder, @@ -425,6 +470,9 @@ impl CosmosClientBuilder { EmulatorServerCertValidation::DangerousDisabled, ); } + if let Some(disabled) = self.gateway20_disabled { + pool_builder = pool_builder.with_gateway20_disabled(disabled); + } driver_runtime_builder = driver_runtime_builder.with_connection_pool(pool_builder.build()?); #[cfg(feature = "fault_injection")] diff --git a/sdk/cosmos/azure_data_cosmos/src/constants.rs b/sdk/cosmos/azure_data_cosmos/src/constants.rs index 733ce125243..da118ef50b0 100644 --- a/sdk/cosmos/azure_data_cosmos/src/constants.rs +++ b/sdk/cosmos/azure_data_cosmos/src/constants.rs @@ -18,6 +18,8 @@ macro_rules! cosmos_headers { /// A list of all Cosmos DB specific headers that should be allowed in logging. pub const COSMOS_ALLOWED_HEADERS: &[&HeaderName] = &[ $(&$name,)* + &azure_data_cosmos_driver::constants::GATEWAY20_OPERATION_TYPE, + &azure_data_cosmos_driver::constants::GATEWAY20_RESOURCE_TYPE, ]; }; } @@ -185,9 +187,6 @@ cosmos_headers! { COSMOS_QUORUM_ACKED_LLSN => "x-ms-cosmos-quorum-acked-llsn", REQUEST_DURATION_MS => "x-ms-request-duration-ms", COSMOS_INTERNAL_PARTITION_ID => "x-ms-cosmos-internal-partition-id", - // Thin Client - THINCLIENT_PROXY_OPERATION_TYPE => "x-ms-thinclient-proxy-operation-type", - THINCLIENT_PROXY_RESOURCE_TYPE => "x-ms-thinclient-proxy-resource-type", // Client ID CLIENT_ID => "x-ms-client-id", // these are not actually sent but are used internally for fault injection diff --git a/sdk/cosmos/azure_data_cosmos/src/cosmos_request.rs b/sdk/cosmos/azure_data_cosmos/src/cosmos_request.rs index 4ee5e0f1a59..c56782dd66d 100644 --- a/sdk/cosmos/azure_data_cosmos/src/cosmos_request.rs +++ b/sdk/cosmos/azure_data_cosmos/src/cosmos_request.rs @@ -2,7 +2,7 @@ // Licensed under the MIT License. 
#[cfg(feature = "fault_injection")] -use crate::fault_injection::FaultOperationType; +use crate::fault_injection::fault_operation_for_sdk; use crate::operation_context::OperationType; use crate::options::ExcludedRegions; use crate::request_context::RequestContext; @@ -153,10 +153,7 @@ impl CosmosRequest { #[cfg(feature = "fault_injection")] pub fn add_fault_injection_headers(&mut self) { - let fault_op = FaultOperationType::from_operation_and_resource( - &self.operation_type, - &self.resource_type, - ); + let fault_op = fault_operation_for_sdk(&self.operation_type, &self.resource_type); if let Some(op) = fault_op { self.headers.insert( diff --git a/sdk/cosmos/azure_data_cosmos/src/driver_bridge.rs b/sdk/cosmos/azure_data_cosmos/src/driver_bridge.rs index 2c842808a19..31c1cc400c8 100644 --- a/sdk/cosmos/azure_data_cosmos/src/driver_bridge.rs +++ b/sdk/cosmos/azure_data_cosmos/src/driver_bridge.rs @@ -94,128 +94,6 @@ fn driver_response_headers_to_headers(cosmos_headers: &CosmosResponseHeaders) -> headers } -/// Translates SDK fault injection rules into driver fault injection rules. -/// -/// The `enabled` and `hit_count` state is shared between the SDK and driver -/// rules via `Arc`, so toggling a rule in tests affects both paths. 
-#[cfg(feature = "fault_injection")] -pub(crate) fn sdk_fi_rules_to_driver_fi_rules( - sdk_rules: &[std::sync::Arc], -) -> Vec> { - use crate::fault_injection::{ - FaultInjectionErrorType as SdkErrorType, FaultOperationType as SdkOpType, - }; - use azure_data_cosmos_driver::fault_injection::{ - self as driver_fi, FaultInjectionConditionBuilder as DriverConditionBuilder, - FaultInjectionResultBuilder as DriverResultBuilder, - FaultInjectionRuleBuilder as DriverRuleBuilder, - }; - use azure_data_cosmos_driver::options::Region; - - sdk_rules - .iter() - .map(|sdk_rule| { - // Translate condition - let mut cond_builder = DriverConditionBuilder::new(); - if let Some(op) = &sdk_rule.condition.operation_type { - let driver_op = match op { - SdkOpType::ReadItem => driver_fi::FaultOperationType::ReadItem, - SdkOpType::QueryItem => driver_fi::FaultOperationType::QueryItem, - SdkOpType::CreateItem => driver_fi::FaultOperationType::CreateItem, - SdkOpType::UpsertItem => driver_fi::FaultOperationType::UpsertItem, - SdkOpType::ReplaceItem => driver_fi::FaultOperationType::ReplaceItem, - SdkOpType::DeleteItem => driver_fi::FaultOperationType::DeleteItem, - SdkOpType::PatchItem => driver_fi::FaultOperationType::PatchItem, - SdkOpType::BatchItem => driver_fi::FaultOperationType::BatchItem, - SdkOpType::ChangeFeedItem => driver_fi::FaultOperationType::ChangeFeedItem, - SdkOpType::MetadataReadContainer => { - driver_fi::FaultOperationType::MetadataReadContainer - } - SdkOpType::MetadataReadDatabaseAccount => { - driver_fi::FaultOperationType::MetadataReadDatabaseAccount - } - SdkOpType::MetadataQueryPlan => { - driver_fi::FaultOperationType::MetadataQueryPlan - } - SdkOpType::MetadataPartitionKeyRanges => { - driver_fi::FaultOperationType::MetadataPartitionKeyRanges - } - }; - cond_builder = cond_builder.with_operation_type(driver_op); - } - if let Some(region) = &sdk_rule.condition.region { - cond_builder = cond_builder.with_region(Region::new(region.to_string())); - } - if let 
Some(container_id) = &sdk_rule.condition.container_id { - cond_builder = cond_builder.with_container_id(container_id.clone()); - } - - // Translate result - let mut result_builder = DriverResultBuilder::new(); - if let Some(err) = &sdk_rule.result.error_type { - let driver_err = match err { - SdkErrorType::InternalServerError => { - driver_fi::FaultInjectionErrorType::InternalServerError - } - SdkErrorType::TooManyRequests => { - driver_fi::FaultInjectionErrorType::TooManyRequests - } - SdkErrorType::ReadSessionNotAvailable => { - driver_fi::FaultInjectionErrorType::ReadSessionNotAvailable - } - SdkErrorType::Timeout => driver_fi::FaultInjectionErrorType::Timeout, - SdkErrorType::ServiceUnavailable => { - driver_fi::FaultInjectionErrorType::ServiceUnavailable - } - SdkErrorType::PartitionIsGone => { - driver_fi::FaultInjectionErrorType::PartitionIsGone - } - SdkErrorType::WriteForbidden => { - driver_fi::FaultInjectionErrorType::WriteForbidden - } - SdkErrorType::DatabaseAccountNotFound => { - driver_fi::FaultInjectionErrorType::DatabaseAccountNotFound - } - SdkErrorType::ConnectionError => { - driver_fi::FaultInjectionErrorType::ConnectionError - } - SdkErrorType::ResponseTimeout => { - driver_fi::FaultInjectionErrorType::ResponseTimeout - } - }; - result_builder = result_builder.with_error(driver_err); - } - if sdk_rule.result.delay > std::time::Duration::ZERO { - result_builder = result_builder.with_delay(sdk_rule.result.delay); - } - let prob = sdk_rule.result.probability(); - if prob < 1.0 { - result_builder = result_builder.with_probability(prob); - } - // Note: custom_response translation is skipped for now. - // None of the current failing tests use custom responses. 
- - // Build driver rule with shared state - let mut rule_builder = - DriverRuleBuilder::new(sdk_rule.id.clone(), result_builder.build()) - .with_condition(cond_builder.build()) - .with_shared_state(sdk_rule.shared_enabled(), sdk_rule.shared_hit_count()); - - if let Some(end_time) = sdk_rule.end_time { - rule_builder = rule_builder.with_end_time(end_time); - } - if let Some(hit_limit) = sdk_rule.hit_limit { - rule_builder = rule_builder.with_hit_limit(hit_limit); - } - // SDK start_time is always set (Instant::now() by default). - // Driver start_time is Option. - rule_builder = rule_builder.with_start_time(sdk_rule.start_time); - - std::sync::Arc::new(rule_builder.build()) - }) - .collect() -} - #[cfg(test)] mod tests { use super::*; diff --git a/sdk/cosmos/azure_data_cosmos/src/fault_injection/client_builder.rs b/sdk/cosmos/azure_data_cosmos/src/fault_injection/client_builder.rs index 6ed1d330f90..97c5b9571f3 100644 --- a/sdk/cosmos/azure_data_cosmos/src/fault_injection/client_builder.rs +++ b/sdk/cosmos/azure_data_cosmos/src/fault_injection/client_builder.rs @@ -8,7 +8,7 @@ use std::sync::Arc; use azure_core::http::Transport; use super::http_client::FaultClient; -use super::rule::FaultInjectionRule; +use super::FaultInjectionRule; /// Builder for creating a fault injection client. /// diff --git a/sdk/cosmos/azure_data_cosmos/src/fault_injection/condition.rs b/sdk/cosmos/azure_data_cosmos/src/fault_injection/condition.rs deleted file mode 100644 index 1201497c7d4..00000000000 --- a/sdk/cosmos/azure_data_cosmos/src/fault_injection/condition.rs +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -//! Defines conditions for when fault injection rules should be applied. - -use super::FaultOperationType; -use crate::regions::Region; - -/// Defines the condition under which a fault injection rule should be applied. 
-#[derive(Clone, Default, Debug)] -pub struct FaultInjectionCondition { - /// The type of operation to which the fault injection applies. - pub operation_type: Option<FaultOperationType>, - /// The region to which the fault injection applies. - pub region: Option<Region>, - /// The container ID to which the fault injection applies. - pub container_id: Option<String>, -} - -/// Builder for creating a FaultInjectionCondition. -#[derive(Default)] -pub struct FaultInjectionConditionBuilder { - operation_type: Option<FaultOperationType>, - region: Option<Region>, - container_id: Option<String>, -} - -impl FaultInjectionConditionBuilder { - /// Creates a new FaultInjectionConditionBuilder with default values. - pub fn new() -> Self { - Self { - operation_type: None, - region: None, - container_id: None, - } - } - - /// Sets the operation type to which the fault injection applies. - pub fn with_operation_type(mut self, operation_type: FaultOperationType) -> Self { - self.operation_type = Some(operation_type); - self - } - - /// Sets the region to which the fault injection applies. - pub fn with_region(mut self, region: Region) -> Self { - self.region = Some(region); - self - } - - /// Sets the container ID to which the fault injection applies. - pub fn with_container_id(mut self, container_id: impl Into<String>) -> Self { - self.container_id = Some(container_id.into()); - self - } - - /// Builds the FaultInjectionCondition.
- pub fn build(self) -> FaultInjectionCondition { - FaultInjectionCondition { - operation_type: self.operation_type, - region: self.region, - container_id: self.container_id, - } - } -} - -#[cfg(test)] -mod tests { - use super::FaultInjectionConditionBuilder; - - #[test] - fn builder_default() { - let builder = FaultInjectionConditionBuilder::default(); - let condition = builder.build(); - assert!(condition.operation_type.is_none()); - assert!(condition.region.is_none()); - assert!(condition.container_id.is_none()); - } -} diff --git a/sdk/cosmos/azure_data_cosmos/src/fault_injection/http_client.rs b/sdk/cosmos/azure_data_cosmos/src/fault_injection/http_client.rs index 1a18e20e196..55887c012a3 100644 --- a/sdk/cosmos/azure_data_cosmos/src/fault_injection/http_client.rs +++ b/sdk/cosmos/azure_data_cosmos/src/fault_injection/http_client.rs @@ -3,10 +3,9 @@ // cSpell:ignore evals -use super::result::FaultInjectionResult; -use super::rule::FaultInjectionRule; -use super::FaultInjectionErrorType; -use super::FaultOperationType; +use super::{ + FaultInjectionErrorType, FaultInjectionResult, FaultInjectionRule, FaultOperationType, +}; use crate::constants::{self, SubStatusCode}; use async_trait::async_trait; use azure_core::error::ErrorKind; @@ -43,20 +42,22 @@ impl FaultClient { return false; } - // Check if the rule has started - if now < rule.start_time { - return false; + // Check if the rule has started (driver default = always-active) + if let Some(start) = rule.start_time() { + if now < start { + return false; + } } // Check if the rule has expired - if let Some(end_time) = rule.end_time { + if let Some(end_time) = rule.end_time() { if now >= end_time { return false; } } // Check if we've exceeded the hit limit on the rule - if let Some(hit_limit) = rule.hit_limit { + if let Some(hit_limit) = rule.hit_limit() { if rule.hit_count() >= hit_limit { return false; } @@ -67,11 +68,11 @@ impl FaultClient { /// Checks if the request matches the rule's condition. 
fn matches_condition(&self, request: &Request, rule: &FaultInjectionRule) -> bool { - let condition = &rule.condition; + let condition = rule.condition(); let mut matches = true; // Check operation type if specified - if let Some(expected_op) = condition.operation_type { + if let Some(expected_op) = condition.operation_type() { let request_op = request .headers() .get_optional_str(&constants::FAULT_INJECTION_OPERATION) @@ -89,14 +90,14 @@ impl FaultClient { } // Check region if specified - if let Some(region) = &condition.region { + if let Some(region) = condition.region() { if !request.url().as_str().contains(region.as_str()) { matches = false; } } // Check container ID if specified - if let Some(container_id) = &condition.container_id { + if let Some(container_id) = condition.container_id() { if !request.url().as_str().contains(container_id) { matches = false; } @@ -127,16 +128,16 @@ impl FaultClient { } // Check for custom response first (takes precedence over error injection) - if let Some(ref custom) = server_error.custom_response { + if let Some(custom) = server_error.custom_response() { return Some(Ok(AsyncRawResponse::from_bytes( - custom.status_code, - custom.headers.clone(), - custom.body.clone(), + custom.status_code(), + custom.headers().clone(), + custom.body().to_vec(), ))); } // Generate the appropriate error based on error type - let error_type = match server_error.error_type { + let error_type = match server_error.error_type() { Some(et) => et, None => return None, // No error type set, pass through }; @@ -196,6 +197,14 @@ impl FaultClient { Some(SubStatusCode::DATABASE_ACCOUNT_NOT_FOUND), "Database Account Not Found - Injected fault", ), + // The driver enum is `#[non_exhaustive]`; new variants + // surface as a generic injected service-unavailable until the + // SDK is taught to render them. 
+ _ => ( + StatusCode::ServiceUnavailable, + None, + "Unknown injected fault", + ), }; let raw_response = sub_status.map(|ss| { @@ -221,10 +230,7 @@ impl HttpClient for FaultClient { async fn execute_request(&self, request: &Request) -> azure_core::Result<AsyncRawResponse> { // Find applicable rule and clone the result if needed - let (fault_result, matched_rule): ( - Option<FaultInjectionResult>, - Option<Arc<FaultInjectionRule>>, - ) = { + let fault_result: Option<FaultInjectionResult> = { let rules = self.rules.lock().unwrap(); let mut applicable_rule_index: Option<usize> = None; @@ -239,9 +245,9 @@ if let Some(index) = applicable_rule_index { let rule = &rules[index]; rule.increment_hit_count(); - (Some(rule.result.clone()), Some(Arc::clone(rule))) + Some(rule.result().clone()) } else { - (None, None) + None } }; @@ -262,33 +268,19 @@ .remove(constants::FAULT_INJECTION_OPERATION); // No fault injection or delay-only fault, proceed with actual request - let result = self.inner.execute_request(&clean_request).await; - - // Record response status only for true spy rules: no error_type, - // no custom_response, and no delay. This excludes probability-skipped - // faults and any rule that injected a delay.
- if let (Some(rule), Some(ref fr), Ok(ref response)) = - (&matched_rule, &fault_result, &result) - { - if fr.error_type.is_none() - && fr.custom_response.is_none() - && fr.delay == Duration::ZERO - { - rule.record_passthrough_status(response.status()); - } - } - - result + self.inner.execute_request(&clean_request).await }; // Apply delay after the request is sent if let Some(result) = fault_result { - if result.delay > Duration::ZERO { - let delay = azure_core::time::Duration::try_from(result.delay) - .unwrap_or(azure_core::time::Duration::ZERO); - azure_core::async_runtime::get_async_runtime() - .sleep(delay) - .await; + if let Some(delay) = result.delay() { + if delay > Duration::ZERO { + let delay = azure_core::time::Duration::try_from(delay) + .unwrap_or(azure_core::time::Duration::ZERO); + azure_core::async_runtime::get_async_runtime() + .sleep(delay) + .await; + } } } @@ -301,13 +293,13 @@ mod tests { use super::FaultClient; use crate::constants::{SubStatusCode, SUB_STATUS}; use crate::fault_injection::{ - CustomResponse, FaultInjectionConditionBuilder, FaultInjectionErrorType, + CustomResponseBuilder, FaultInjectionConditionBuilder, FaultInjectionErrorType, FaultInjectionResultBuilder, FaultInjectionRuleBuilder, FaultOperationType, }; use crate::regions::Region; use async_trait::async_trait; use azure_core::error::ErrorKind; - use azure_core::http::{headers::Headers, AsyncRawResponse, HttpClient, Method, Request, Url}; + use azure_core::http::{AsyncRawResponse, HttpClient, Method, Request, Url}; use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -730,11 +722,11 @@ mod tests { let body = b"{\"id\": \"test-account\"}".to_vec(); let result = FaultInjectionResultBuilder::new() - .with_custom_response(CustomResponse { - status_code: azure_core::http::StatusCode::Ok, - headers: Headers::new(), - body: body.clone(), - }) + .with_custom_response( + CustomResponseBuilder::new(azure_core::http::StatusCode::Ok) + 
.with_body(body.clone()) + .build(), + ) .build(); let rule = FaultInjectionRuleBuilder::new("custom-response-rule", result).build(); diff --git a/sdk/cosmos/azure_data_cosmos/src/fault_injection/mod.rs b/sdk/cosmos/azure_data_cosmos/src/fault_injection/mod.rs index f0cee5374be..ede03f2b5bd 100644 --- a/sdk/cosmos/azure_data_cosmos/src/fault_injection/mod.rs +++ b/sdk/cosmos/azure_data_cosmos/src/fault_injection/mod.rs @@ -3,9 +3,17 @@ //! Fault injection framework for testing Cosmos DB client behavior under error conditions. //! -//! This module provides a fault injection framework that intercepts HTTP requests at the -//! transport layer, below the retry policy. When a fault is injected, it triggers the same -//! retry and failover behavior as a real service error. This enables testing of: +//! This module wraps the driver's fault-injection primitives — every type +//! except [`FaultInjectionClientBuilder`] is re-exported directly from +//! [`azure_data_cosmos_driver::fault_injection`]. The SDK only owns the +//! [`FaultInjectionClientBuilder`] (which produces an [`azure_core::http::Transport`] +//! that the SDK pipeline plugs in) and a small adapter for translating SDK-side +//! `OperationType` / `ResourceType` pairs into the driver's +//! [`FaultOperationType`]. +//! +//! Below the transport layer, fault injection intercepts HTTP requests and +//! triggers the same retry and failover behavior as a real service error. +//! It enables testing of: //! //! - Error handling for various HTTP status codes (503, 500, 429, 408, etc.) //! - Retry logic and backoff behavior @@ -27,7 +35,7 @@ //! configured builder to [`CosmosClientBuilder::with_fault_injection()`](crate::CosmosClientBuilder::with_fault_injection) //! to enable fault injection and wrap the HTTP transport with a fault-injecting client. //! - [`FaultInjectionCondition`] — Defines when a fault should be applied, filtering by -//! operation type, region, or container ID. +//! 
operation type, region, container ID, or transport kind. //! - [`FaultInjectionResult`] — Defines what error to inject, including error type, delay, //! and probability. //! - [`FaultInjectionRule`] — Combines a condition with a result and additional controls @@ -92,177 +100,76 @@ //! Rules are evaluated in the order they were added. The first matching rule is applied. //! All specified conditions in a [`FaultInjectionCondition`] must match (AND logic): //! if no conditions are specified, the rule matches all requests. -//! mod client_builder; -mod condition; mod http_client; -mod result; -mod rule; - -use std::fmt; -use std::str::FromStr; - -use crate::operation_context::OperationType; -use crate::resource_context::ResourceType; pub use client_builder::FaultInjectionClientBuilder; -pub use condition::{FaultInjectionCondition, FaultInjectionConditionBuilder}; -pub use result::{ - CustomResponse, CustomResponseBuilder, FaultInjectionResult, FaultInjectionResultBuilder, + +#[doc(inline)] +pub use azure_data_cosmos_driver::fault_injection::{ + CustomResponse, CustomResponseBuilder, FaultInjectionCondition, FaultInjectionConditionBuilder, + FaultInjectionErrorType, FaultInjectionResult, FaultInjectionResultBuilder, FaultInjectionRule, + FaultInjectionRuleBuilder, FaultOperationType, }; -pub use rule::{FaultInjectionRule, FaultInjectionRuleBuilder}; -/// Represents different server error types that can be injected for fault testing. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum FaultInjectionErrorType { - /// 500 from server. - InternalServerError, - /// 429 from server. - TooManyRequests, - /// 404-1002 from server. - ReadSessionNotAvailable, - /// 408 from server. - Timeout, - /// Simulate service unavailable (503). - ServiceUnavailable, - /// 410-1002 from server. - PartitionIsGone, - /// 403-3 Forbidden from server. - WriteForbidden, - /// 403-1008 Forbidden from server. 
- DatabaseAccountNotFound, - /// Simulates a connection failure (e.g., connection refused, DNS failure). - /// Produces an `ErrorKind::Io` error, not an HTTP response error. - ConnectionError, - /// Simulates a response timeout (request sent but no response received). - /// Produces an `ErrorKind::Io` error, not an HTTP response error. - ResponseTimeout, -} +/// Re-export of the driver's [`TransportKind`] enum so SDK consumers can +/// scope fault-injection rules to a specific transport (Gateway 1.x vs +/// Gateway 2.0) without depending on the driver crate directly. +pub use azure_data_cosmos_driver::diagnostics::TransportKind; -/// The type of operation to which the fault injection applies. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum FaultOperationType { - /// Read items. - ReadItem, - /// Query items. - QueryItem, - /// Create item. - CreateItem, - /// Upsert item. - UpsertItem, - /// Replace item. - ReplaceItem, - /// Delete item. - DeleteItem, - /// Patch item. - PatchItem, - /// Batch item. - BatchItem, - /// Read change feed items. - ChangeFeedItem, - /// Read container request. - MetadataReadContainer, - /// Read database account request. - MetadataReadDatabaseAccount, - /// Query query plan request. - MetadataQueryPlan, - /// Partition key ranges request. - MetadataPartitionKeyRanges, -} +use crate::operation_context::OperationType as SdkOperationType; +use crate::resource_context::ResourceType as SdkResourceType; -impl FaultOperationType { - /// Returns the string representation of this operation type. 
- pub fn as_str(&self) -> &'static str { - match self { - FaultOperationType::ReadItem => "ReadItem", - FaultOperationType::QueryItem => "QueryItem", - FaultOperationType::CreateItem => "CreateItem", - FaultOperationType::UpsertItem => "UpsertItem", - FaultOperationType::ReplaceItem => "ReplaceItem", - FaultOperationType::DeleteItem => "DeleteItem", - FaultOperationType::PatchItem => "PatchItem", - FaultOperationType::BatchItem => "BatchItem", - FaultOperationType::ChangeFeedItem => "ChangeFeedItem", - FaultOperationType::MetadataReadContainer => "MetadataReadContainer", - FaultOperationType::MetadataReadDatabaseAccount => "MetadataReadDatabaseAccount", - FaultOperationType::MetadataQueryPlan => "MetadataQueryPlan", - FaultOperationType::MetadataPartitionKeyRanges => "MetadataPartitionKeyRanges", +/// Maps an SDK-side `(OperationType, ResourceType)` pair to the driver's +/// [`FaultOperationType`]. +/// +/// This mirrors `FaultOperationType::from_operation_and_resource` on the +/// driver, but takes SDK enums directly so SDK callers don't need to convert +/// to driver enums first. Returns `None` if the combination doesn't map to a +/// known fault operation type. +pub(crate) fn fault_operation_for_sdk( + operation_type: &SdkOperationType, + resource_type: &SdkResourceType, +) -> Option<FaultOperationType> { + match (operation_type, resource_type) { + (SdkOperationType::Read, SdkResourceType::Documents) => Some(FaultOperationType::ReadItem), + (SdkOperationType::Query, SdkResourceType::Documents) => { + Some(FaultOperationType::QueryItem) - } - - /// Converts an operation type and resource type pair into a fault injection operation type. - /// - /// Returns `None` if the combination does not map to a known fault operation type.
- pub fn from_operation_and_resource( - operation_type: &OperationType, - resource_type: &ResourceType, - ) -> Option { - match (operation_type, resource_type) { - (OperationType::Read, ResourceType::Documents) => Some(FaultOperationType::ReadItem), - (OperationType::Query, ResourceType::Documents) => Some(FaultOperationType::QueryItem), - (OperationType::Create, ResourceType::Documents) => { - Some(FaultOperationType::CreateItem) - } - (OperationType::Upsert, ResourceType::Documents) => { - Some(FaultOperationType::UpsertItem) - } - (OperationType::Replace, ResourceType::Documents) => { - Some(FaultOperationType::ReplaceItem) - } - (OperationType::Delete, ResourceType::Documents) => { - Some(FaultOperationType::DeleteItem) - } - (OperationType::Patch, ResourceType::Documents) => Some(FaultOperationType::PatchItem), - (OperationType::Batch, ResourceType::Documents) => Some(FaultOperationType::BatchItem), - (OperationType::ReadFeed, ResourceType::Documents) => { - Some(FaultOperationType::ChangeFeedItem) - } - (OperationType::Read, ResourceType::Containers) => { - Some(FaultOperationType::MetadataReadContainer) - } - (OperationType::Read, ResourceType::DatabaseAccount) => { - Some(FaultOperationType::MetadataReadDatabaseAccount) - } - (OperationType::QueryPlan, ResourceType::Documents) => { - Some(FaultOperationType::MetadataQueryPlan) - } - (OperationType::ReadFeed, ResourceType::PartitionKeyRanges) => { - Some(FaultOperationType::MetadataPartitionKeyRanges) - } - _ => None, + (SdkOperationType::Create, SdkResourceType::Documents) => { + Some(FaultOperationType::CreateItem) } - } -} - -impl fmt::Display for FaultOperationType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(self.as_str()) - } -} - -impl FromStr for FaultOperationType { - type Err = (); - - /// Parses a string into a `FaultOperationType`. - /// - /// Returns `Err(())` if the string is not a recognized operation type. 
- fn from_str(s: &str) -> Result<Self, Self::Err> { - match s { - "ReadItem" => Ok(FaultOperationType::ReadItem), - "QueryItem" => Ok(FaultOperationType::QueryItem), - "CreateItem" => Ok(FaultOperationType::CreateItem), - "UpsertItem" => Ok(FaultOperationType::UpsertItem), - "ReplaceItem" => Ok(FaultOperationType::ReplaceItem), - "DeleteItem" => Ok(FaultOperationType::DeleteItem), - "PatchItem" => Ok(FaultOperationType::PatchItem), - "BatchItem" => Ok(FaultOperationType::BatchItem), - "ChangeFeedItem" => Ok(FaultOperationType::ChangeFeedItem), - "MetadataReadContainer" => Ok(FaultOperationType::MetadataReadContainer), - "MetadataReadDatabaseAccount" => Ok(FaultOperationType::MetadataReadDatabaseAccount), - "MetadataQueryPlan" => Ok(FaultOperationType::MetadataQueryPlan), - "MetadataPartitionKeyRanges" => Ok(FaultOperationType::MetadataPartitionKeyRanges), - _ => Err(()), + (SdkOperationType::Upsert, SdkResourceType::Documents) => { + Some(FaultOperationType::UpsertItem) + } + (SdkOperationType::Replace, SdkResourceType::Documents) => { + Some(FaultOperationType::ReplaceItem) + } + (SdkOperationType::Delete, SdkResourceType::Documents) => { + Some(FaultOperationType::DeleteItem) + } + (SdkOperationType::Patch, SdkResourceType::Documents) => { + Some(FaultOperationType::PatchItem) + } + (SdkOperationType::Batch, SdkResourceType::Documents) => { + Some(FaultOperationType::BatchItem) + } + (SdkOperationType::ReadFeed, SdkResourceType::Documents) => { + Some(FaultOperationType::ChangeFeedItem) + } + (SdkOperationType::Read, SdkResourceType::Containers) => { + Some(FaultOperationType::MetadataReadContainer) + } + (SdkOperationType::Read, SdkResourceType::DatabaseAccount) => { + Some(FaultOperationType::MetadataReadDatabaseAccount) + } + (SdkOperationType::QueryPlan, SdkResourceType::Documents) => { + Some(FaultOperationType::MetadataQueryPlan) + } + (SdkOperationType::ReadFeed, SdkResourceType::PartitionKeyRanges) => { + Some(FaultOperationType::MetadataPartitionKeyRanges) } + _ => None,
} } diff --git a/sdk/cosmos/azure_data_cosmos/src/fault_injection/result.rs b/sdk/cosmos/azure_data_cosmos/src/fault_injection/result.rs deleted file mode 100644 index af0db1a8ebd..00000000000 --- a/sdk/cosmos/azure_data_cosmos/src/fault_injection/result.rs +++ /dev/null @@ -1,240 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -//! Defines fault injection results including server errors. - -use std::time::Duration; - -use azure_core::http::{ - headers::{HeaderName, HeaderValue, Headers}, - StatusCode, -}; - -use crate::constants::SubStatusCode; - -use super::FaultInjectionErrorType; - -/// A synthetic response to return when a fault injection rule matches. - -/// -/// Instead of injecting an error, this returns a successful response with -/// the specified status code, headers, and body. Useful for mocking service -/// responses such as `GetDatabaseAccount` in tests. -#[derive(Clone, Debug)] -pub struct CustomResponse { - /// The HTTP status code for the synthetic response. - pub status_code: StatusCode, - /// The headers for the synthetic response. - pub headers: Headers, - /// The body for the synthetic response. - pub body: Vec<u8>, -} - -/// Builder for creating a [`CustomResponse`]. - -/// -/// Provides a fluent API for constructing synthetic HTTP responses -/// for fault injection testing. -/// -/// # Example -/// -/// ```rust -/// use azure_data_cosmos::fault_injection::CustomResponseBuilder; -/// use azure_core::http::StatusCode; -/// -/// let response = CustomResponseBuilder::new(StatusCode::Forbidden) -/// .with_sub_status(3) -/// .with_body(b"Write Forbidden".to_vec()) -/// .build(); -/// -/// assert_eq!(response.status_code, StatusCode::Forbidden); -/// ``` -pub struct CustomResponseBuilder { - status_code: StatusCode, - headers: Headers, - body: Vec<u8>, -} - -impl CustomResponseBuilder { - /// Creates a new builder with the specified HTTP status code.
- pub fn new(status_code: StatusCode) -> Self { - Self { - status_code, - headers: Headers::new(), - body: Vec::new(), - } - } - - /// Adds a header to the response. - pub fn with_header( - mut self, - name: impl Into<HeaderName>, - value: impl Into<HeaderValue>, - ) -> Self { - self.headers.insert(name, value); - self - } - - /// Sets the `x-ms-substatus` header to the given numeric sub-status code. - /// - /// This is a convenience method equivalent to calling - /// `with_header("x-ms-substatus", code.to_string())`. - pub fn with_sub_status(self, code: impl Into<SubStatusCode>) -> Self { - let code = code.into(); - self.with_header(crate::constants::SUB_STATUS, code.to_string()) - } - - /// Sets the response body. - pub fn with_body(mut self, body: impl Into<Vec<u8>>) -> Self { - self.body = body.into(); - self - } - - /// Builds the [`CustomResponse`]. - pub fn build(self) -> CustomResponse { - CustomResponse { - status_code: self.status_code, - headers: self.headers, - body: self.body, - } - } -} - -/// Represents a server error to be injected. -#[derive(Clone, Debug)] -pub struct FaultInjectionResult { - /// The type of server error to inject. - pub error_type: Option<FaultInjectionErrorType>, - /// A custom response to return instead of injecting an error. - pub custom_response: Option<CustomResponse>, - /// Delay before injecting the error. - pub delay: Duration, - /// Probability of injecting the error (0.0 to 1.0). - probability: f32, -} - -impl FaultInjectionResult { - /// Returns the probability of injecting the fault (0.0 to 1.0). - pub fn probability(&self) -> f32 { - self.probability - } -} - -/// Builder for creating a FaultInjectionResult. -pub struct FaultInjectionResultBuilder { - error_type: Option<FaultInjectionErrorType>, - custom_response: Option<CustomResponse>, - delay: Duration, - probability: f32, -} - -impl FaultInjectionResultBuilder { - /// Creates a new FaultInjectionResultBuilder with default values.
- pub fn new() -> Self { - Self { - error_type: None, - custom_response: None, - delay: Duration::ZERO, - probability: 1.0, - } - } - - /// Sets the error type to inject. - pub fn with_error(mut self, error_type: FaultInjectionErrorType) -> Self { - self.error_type = Some(error_type); - self - } - - /// Sets a custom response to return instead of injecting an error. - /// - /// When set, the fault injection rule returns this synthetic response - /// rather than forwarding the request to the real service. This takes - /// precedence over `error_type` if both are set. - pub fn with_custom_response(mut self, response: CustomResponse) -> Self { - self.custom_response = Some(response); - self - } - - /// Sets the delay before injecting the error. - pub fn with_delay(mut self, delay: Duration) -> Self { - self.delay = delay; - self - } - - /// Sets the probability of injecting the error (0.0 to 1.0). - pub fn with_probability(mut self, probability: f32) -> Self { - self.probability = probability.clamp(0.0, 1.0); - self - } - - /// Builds the FaultInjectionResult. 
- /// - pub fn build(self) -> FaultInjectionResult { - FaultInjectionResult { - error_type: self.error_type, - custom_response: self.custom_response, - delay: self.delay, - probability: self.probability, - } - } -} - -impl Default for FaultInjectionResultBuilder { - fn default() -> Self { - Self::new() - } -} - -#[cfg(test)] -mod tests { - use super::{CustomResponse, FaultInjectionResultBuilder}; - use crate::fault_injection::FaultInjectionErrorType; - use azure_core::http::{headers::Headers, StatusCode}; - use std::time::Duration; - - #[test] - fn builder_default_values() { - let error = FaultInjectionResultBuilder::new() - .with_error(FaultInjectionErrorType::Timeout) - .build(); - - assert_eq!(error.error_type.unwrap(), FaultInjectionErrorType::Timeout); - assert_eq!(error.delay, Duration::ZERO); - assert!((error.probability() - 1.0).abs() < f32::EPSILON); - } - - #[test] - fn builder_probability_clamped_above() { - let error = FaultInjectionResultBuilder::new() - .with_error(FaultInjectionErrorType::ServiceUnavailable) - .with_probability(1.5) - .build(); - - assert!((error.probability() - 1.0).abs() < f32::EPSILON); - } - - #[test] - fn builder_probability_clamped_below() { - let error = FaultInjectionResultBuilder::new() - .with_error(FaultInjectionErrorType::ServiceUnavailable) - .with_probability(-0.5) - .build(); - - assert!(error.probability().abs() < f32::EPSILON); - } - - #[test] - fn builder_with_custom_response() { - let body = b"{\"test\": true}".to_vec(); - let result = FaultInjectionResultBuilder::new() - .with_custom_response(CustomResponse { - status_code: StatusCode::Ok, - headers: Headers::new(), - body: body.clone(), - }) - .build(); - - assert!(result.error_type.is_none()); - let custom = result.custom_response.unwrap(); - assert_eq!(custom.status_code, StatusCode::Ok); - assert_eq!(custom.body, body); - } -} diff --git a/sdk/cosmos/azure_data_cosmos/src/fault_injection/rule.rs b/sdk/cosmos/azure_data_cosmos/src/fault_injection/rule.rs 
deleted file mode 100644 index 2d87aa8bc50..00000000000 --- a/sdk/cosmos/azure_data_cosmos/src/fault_injection/rule.rs +++ /dev/null @@ -1,249 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -//! Defines fault injection rules that combine conditions and results. - -use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; -use std::sync::{Arc, Mutex}; -use std::time::Instant; - -use azure_core::http::StatusCode; - -use super::condition::FaultInjectionCondition; -use super::result::FaultInjectionResult; - -/// A fault injection rule that defines when and how to inject faults. -#[derive(Debug)] -pub struct FaultInjectionRule { - /// The condition under which to inject the fault. - pub condition: FaultInjectionCondition, - /// The result to inject when the condition is met. - pub result: FaultInjectionResult, - /// The absolute time at which the rule becomes active. - pub start_time: Instant, - /// The absolute time at which the rule expires, if set. - pub end_time: Option<Instant>, - /// The total hit limit of the rule. - pub hit_limit: Option<u32>, - /// Unique identifier for the fault injection scenario. - pub id: String, - /// Whether the rule is currently enabled. - enabled: Arc<AtomicBool>, - /// Number of times the rule has been matched (including matches where no fault was injected). - hit_count: Arc<AtomicU32>, - /// HTTP status codes of responses for matched requests that passed through without fault injection. - passthrough_statuses: Mutex<Vec<StatusCode>>, -} - -/// Cloning snapshots the current `hit_count` and `enabled` state rather than -/// resetting them, so a clone of a rule that has been hit 5 times starts at 5.
-impl Clone for FaultInjectionRule { - fn clone(&self) -> Self { - Self { - condition: self.condition.clone(), - result: self.result.clone(), - start_time: self.start_time, - end_time: self.end_time, - hit_limit: self.hit_limit, - id: self.id.clone(), - enabled: Arc::new(AtomicBool::new(self.enabled.load(Ordering::SeqCst))), - hit_count: Arc::new(AtomicU32::new(self.hit_count.load(Ordering::SeqCst))), - passthrough_statuses: Mutex::new(self.passthrough_statuses.lock().unwrap().clone()), - } - } -} - -impl FaultInjectionRule { - /// Returns whether the rule is currently enabled. - pub fn is_enabled(&self) -> bool { - self.enabled.load(Ordering::SeqCst) - } - - /// Enables the rule. - pub fn enable(&self) { - self.enabled.store(true, Ordering::SeqCst); - } - - /// Disables the rule. - pub fn disable(&self) { - self.enabled.store(false, Ordering::SeqCst); - } - - /// Returns the number of times this rule has been matched. - /// - /// The hit count is incremented each time the rule's condition matches a - /// request, regardless of whether the fault was actually applied (e.g., - /// probability-based skipping still increments the count). - pub fn hit_count(&self) -> u32 { - self.hit_count.load(Ordering::SeqCst) - } - - /// Increments the hit count by one. - pub(super) fn increment_hit_count(&self) { - self.hit_count.fetch_add(1, Ordering::SeqCst); - } - - /// Resets the hit count to zero. - pub fn reset_hit_count(&self) { - self.hit_count.store(0, Ordering::SeqCst); - } - - /// Returns a shared reference to the enabled flag for cross-path state sharing. - pub(crate) fn shared_enabled(&self) -> Arc<AtomicBool> { - Arc::clone(&self.enabled) - } - - /// Returns a shared reference to the hit count for cross-path state sharing. - pub(crate) fn shared_hit_count(&self) -> Arc<AtomicU32> { - Arc::clone(&self.hit_count) - } - - /// Records the HTTP status code of a response for a matched request that - /// passed through without fault injection (spy/passthrough mode).
- pub(super) fn record_passthrough_status(&self, status: StatusCode) { - self.passthrough_statuses.lock().unwrap().push(status); - } - - /// Returns the HTTP status codes of responses for matched requests that - /// passed through without fault injection. - /// - /// When a rule matches a request but does not inject a fault (e.g., no - /// `error_type` or `custom_response` is set), the real service response - /// status is recorded here. This enables "spy" rules that observe requests - /// without modifying them. - /// - /// The history grows unbounded for the lifetime of the rule. This is - /// designed for test scenarios with a bounded number of requests. - pub fn passthrough_statuses(&self) -> Vec<StatusCode> { - self.passthrough_statuses.lock().unwrap().clone() - } -} - -/// Builder for creating a fault injection rule. -pub struct FaultInjectionRuleBuilder { - /// The condition under which to inject the fault. - condition: FaultInjectionCondition, - /// The result to inject when the condition is met. - result: FaultInjectionResult, - /// The absolute time at which the rule becomes active. - start_time: Instant, - /// The absolute time at which the rule expires. - end_time: Option<Instant>, - /// The total hit limit of the rule. - hit_limit: Option<u32>, - /// Unique identifier for the fault injection scenario. - id: String, -} - -impl FaultInjectionRuleBuilder { - /// Creates a new FaultInjectionRuleBuilder with default values. - /// - /// By default the rule starts immediately and never expires. - pub fn new(id: impl Into<String>, result: FaultInjectionResult) -> Self { - Self { - condition: FaultInjectionCondition::default(), - result, - start_time: Instant::now(), - end_time: None, - hit_limit: None, - id: id.into(), - } - } - - /// Sets the condition for when to inject the fault. - pub fn with_condition(mut self, condition: FaultInjectionCondition) -> Self { - self.condition = condition; - self - } - - /// Sets the result to inject when the condition is met.
- pub fn with_result(mut self, result: FaultInjectionResult) -> Self { - self.result = result; - self - } - - /// Sets the absolute time at which the rule becomes active. - pub fn with_start_time(mut self, start_time: Instant) -> Self { - self.start_time = start_time; - self - } - - /// Sets the absolute time at which the rule expires. - pub fn with_end_time(mut self, end_time: Instant) -> Self { - self.end_time = Some(end_time); - self - } - - /// Sets the total hit limit of the rule. - pub fn with_hit_limit(mut self, hit_limit: u32) -> Self { - self.hit_limit = Some(hit_limit); - self - } - - /// Builds the FaultInjectionRule. - pub fn build(self) -> FaultInjectionRule { - FaultInjectionRule { - condition: self.condition, - result: self.result, - start_time: self.start_time, - end_time: self.end_time, - hit_limit: self.hit_limit, - id: self.id, - enabled: Arc::new(AtomicBool::new(true)), - hit_count: Arc::new(AtomicU32::new(0)), - passthrough_statuses: Mutex::new(Vec::new()), - } - } -} - -#[cfg(test)] -mod tests { - use super::FaultInjectionRuleBuilder; - use crate::fault_injection::{FaultInjectionErrorType, FaultInjectionResultBuilder}; - use std::time::Instant; - - fn create_test_error() -> crate::fault_injection::FaultInjectionResult { - FaultInjectionResultBuilder::new() - .with_error(FaultInjectionErrorType::Timeout) - .build() - } - - #[test] - fn builder_default_values() { - let before = Instant::now(); - let rule = FaultInjectionRuleBuilder::new("test-rule", create_test_error()).build(); - - assert_eq!(rule.id, "test-rule"); - assert!(rule.start_time >= before); - assert!(rule.start_time <= Instant::now()); - assert!(rule.end_time.is_none()); - assert!(rule.hit_limit.is_none()); - assert!(rule.condition.operation_type.is_none()); - assert!(rule.is_enabled()); - assert_eq!(rule.hit_count(), 0); - } - - #[test] - fn hit_count_increments() { - let rule = FaultInjectionRuleBuilder::new("hit-test", create_test_error()).build(); - - 
assert_eq!(rule.hit_count(), 0); - rule.increment_hit_count(); - assert_eq!(rule.hit_count(), 1); - rule.increment_hit_count(); - rule.increment_hit_count(); - assert_eq!(rule.hit_count(), 3); - } - - #[test] - fn reset_hit_count_clears_counter() { - let rule = FaultInjectionRuleBuilder::new("reset-test", create_test_error()).build(); - - rule.increment_hit_count(); - rule.increment_hit_count(); - assert_eq!(rule.hit_count(), 2); - - rule.reset_hit_count(); - assert_eq!(rule.hit_count(), 0); - } -} diff --git a/sdk/cosmos/azure_data_cosmos/tests/emulator_tests/cosmos_fault_injection.rs b/sdk/cosmos/azure_data_cosmos/tests/emulator_tests/cosmos_fault_injection.rs index 75213af2319..3416657a83a 100644 --- a/sdk/cosmos/azure_data_cosmos/tests/emulator_tests/cosmos_fault_injection.rs +++ b/sdk/cosmos/azure_data_cosmos/tests/emulator_tests/cosmos_fault_injection.rs @@ -12,7 +12,7 @@ use super::framework; use azure_core::{http::StatusCode, Uuid}; use azure_data_cosmos::fault_injection::{ FaultInjectionClientBuilder, FaultInjectionConditionBuilder, FaultInjectionErrorType, - FaultInjectionResultBuilder, FaultInjectionRuleBuilder, FaultOperationType, + FaultInjectionResultBuilder, FaultInjectionRuleBuilder, FaultOperationType, TransportKind, }; use azure_data_cosmos::models::{ContainerProperties, ThroughputProperties}; use framework::{get_effective_hub_endpoint, TestClient, TestOptions}; @@ -951,7 +951,7 @@ pub async fn fault_injection_enable_disable_rule() -> Result<(), Box> .build(), ); - assert_eq!(rule.id, "enable-disable-test"); + assert_eq!(rule.id(), "enable-disable-test"); assert!(rule.is_enabled()); let rule_handle = Arc::clone(&rule); @@ -1020,3 +1020,90 @@ pub async fn fault_injection_enable_disable_rule() -> Result<(), Box> ) .await } + +// ---------------------------------------------------------------------------- +// Gateway 2.0 fault injection coverage (Phase 6) +// ---------------------------------------------------------------------------- + +/// Gateway 
2.0 ConnectionError should fall back to the standard gateway +/// transparently — the client must not surface the connection failure to the +/// caller when a usable fallback transport exists. +/// +/// The rule is scoped to [`TransportKind::Gateway20`] via +/// `with_transport_kind`, so it only fires on Gateway 2.0 traffic and never +/// on standard-gateway requests. +/// +/// **Limitation**: the SDK does not yet expose a public Gateway 2.0 enable +/// API on `CosmosClientOptions`, so the SDK currently never selects the +/// Gateway 2.0 transport. Until that toggle lands, this test is gated behind +/// the `gateway20` test category. Once the SDK toggle ships, the assertion +/// should change from "rule never fires" to "read SUCCEEDS via the +/// standard-gateway fallback". +#[tokio::test] +#[cfg_attr( + not(test_category = "gateway20"), + ignore = "requires test_category 'gateway20'" +)] +pub async fn gateway20_connection_error_falls_back_to_standard_gateway( +) -> Result<(), Box<dyn std::error::Error>> { + let server_error = FaultInjectionResultBuilder::new() + .with_error(FaultInjectionErrorType::ConnectionError) + .with_probability(1.0) + .build(); + + let condition = FaultInjectionConditionBuilder::new() + .with_operation_type(FaultOperationType::ReadItem) + .with_transport_kind(TransportKind::Gateway20) + .build(); + + let rule = FaultInjectionRuleBuilder::new("gateway20-conn-error-fallback", server_error) + .with_condition(condition) + .build(); + + let fault_builder = FaultInjectionClientBuilder::new().with_rule(Arc::new(rule)); + + TestClient::run_with_unique_db( + async |run_context, db_client| { + let container_id = format!("Container-{}", Uuid::new_v4()); + let container_client = run_context + .create_container_with_throughput( + db_client, + ContainerProperties::new(container_id.clone(), "/partition_key".into()), + ThroughputProperties::manual(400), + ) + .await?; + + let unique_id = Uuid::new_v4().to_string(); + let item = create_test_item(&unique_id); + let pk =
format!("Partition-{}", unique_id); + let item_id = format!("Item-{}", unique_id); + + container_client + .create_item(&pk, &item_id, &item, None) + .await?; + + let fault_client = run_context + .fault_client() + .expect("fault client should be available"); + let fault_db_client = fault_client.database_client(db_client.id()); + let fault_container_client = fault_db_client.container_client(&container_id).await?; + + // Once the SDK exposes a public Gateway 2.0 enable API, this read + // should SUCCEED via the standard-gateway fallback (the rule + // fires only on Gateway 2.0, leaving the fallback transport + // untouched). + let result = fault_container_client + .read_item::(&pk, &item_id, None) + .await; + assert!( + result.is_ok(), + "Read should succeed via the standard-gateway fallback when \ + the rule is scoped to Gateway 2.0" + ); + + Ok(()) + }, + Some(TestOptions::new().with_fault_injection_builder(fault_builder)), + ) + .await +} diff --git a/sdk/cosmos/azure_data_cosmos/tests/emulator_tests/gateway20_e2e.rs b/sdk/cosmos/azure_data_cosmos/tests/emulator_tests/gateway20_e2e.rs new file mode 100644 index 00000000000..7d4f399ff5c --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos/tests/emulator_tests/gateway20_e2e.rs @@ -0,0 +1,631 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//! End-to-end tests for the Gateway 2.0 transport, exercised through the +//! `azure_data_cosmos` SDK surface (not the underlying driver crate). +//! +//! These tests run against a pre-provisioned Gateway 2.0 account. The +//! endpoint and primary key are read from the +//! `AZURE_COSMOS_GW20_ENDPOINT` and `AZURE_COSMOS_GW20_KEY` environment +//! variables and gated by the `gateway20` test category. They are skipped by +//! default; the main Cosmos Rust pipeline (`sdk/cosmos/ci.yml`) injects those +//! env vars from the `azure-sdk-tests-cosmos` service connection's secret +//! 
variable group, and the `Cosmos_gateway20_live_test` matrix entry sets the +//! `testCategory` to `gateway20` (or `gateway20_multi_region`) so the tests +//! run in CI against the live account. +//! +//! ## What these tests assert today +//! +//! [`CosmosClientBuilder::with_gateway20_disabled`] now propagates the +//! Gateway 2.0 toggle into the underlying driver, so the tests exercise the +//! real SDK opt-in path against the live account. +//! +//! Each implemented test: +//! +//! * Builds a [`CosmosClient`] with `with_gateway20_disabled(false)` (or +//! `true`, for the operator-override scenario), pointing at the +//! `AZURE_COSMOS_GW20_ENDPOINT/_KEY` account. +//! * Provisions a fresh database + container and drives the operation +//! appropriate to the test (CRUD, query, batch, point read). +//! * Asserts the operation succeeds and the standard +//! [`CosmosDiagnostics`] fields (activity ID + server duration) are +//! populated. +//! +//! ## Future work (`TODO`) +//! +//! The SDK-level [`CosmosDiagnostics`] type does not yet surface the driver's +//! `TransportKind` — that gap is documented on `CosmosDiagnostics` itself +//! ("will be expanded ... once the SDK pipeline is ported to the driver's +//! transport pipeline"). Once that exposure lands, each test should be +//! tightened to assert `TransportKind::Gateway20` (or `StandardGateway` for +//! the override case) on the diagnostics instance returned from the +//! operation. +//! +//! The change-feed test stays a placeholder until the SDK gains a public +//! change-feed API on `ContainerClient` (only the routing-layer change-feed +//! plumbing exists today; there is no `ContainerClient::change_feed` to +//! call from a public test). 
+ +#![cfg(feature = "key_auth")] + +use azure_core::credentials::Secret; +use azure_data_cosmos::models::{ContainerProperties, PartitionKeyDefinition}; +use azure_data_cosmos::{ + CosmosAccountEndpoint, CosmosAccountReference, CosmosClient, Query, Region, RoutingStrategy, + TransactionalBatch, +}; +use futures::StreamExt; +use serde::{Deserialize, Serialize}; + +fn read_env(name: &str) -> Option<String> { + std::env::var(name).ok().filter(|v| !v.trim().is_empty()) +} + +/// Returns `Some((endpoint, key))` only when both env vars are set. +fn live_credentials() -> Option<(String, String)> { + Some(( + read_env("AZURE_COSMOS_GW20_ENDPOINT")?, + read_env("AZURE_COSMOS_GW20_KEY")?, + )) +} + +/// Build a [`CosmosClient`] against the live Gateway 2.0 account. +/// +/// `gateway20_disabled = false` opts the client in to Gateway 2.0; passing +/// `true` exercises the operator-override path that pins the client to the +/// standard gateway even when the account advertises a Gateway 2.0 endpoint. +async fn build_client( + endpoint: &str, + key: &str, + gateway20_disabled: bool, +) -> Result<CosmosClient, Box<dyn std::error::Error>> { + let endpoint: CosmosAccountEndpoint = endpoint.parse()?; + let account_ref = + CosmosAccountReference::with_master_key(endpoint, Secret::from(key.to_string())); + let client = CosmosClient::builder() + .with_gateway20_disabled(gateway20_disabled) + .build(account_ref, RoutingStrategy::ProximityTo(Region::EAST_US)) + .await?; + Ok(client) +} + +/// Provisions a fresh database + container scoped to the test invocation and +/// returns the database name (so the caller can drop it) and a container +/// client to drive operations against.
+async fn provision_database_and_container( + client: &CosmosClient, +) -> Result<(String, azure_data_cosmos::clients::ContainerClient), Box<dyn std::error::Error>> { + let unique = azure_core::Uuid::new_v4(); + let db_name = format!("gw20-test-db-{unique}"); + let container_name = format!("gw20-test-container-{unique}"); + + client.create_database(&db_name, None).await?; + let db_client = client.database_client(&db_name); + + let pk_def: PartitionKeyDefinition = "/pk".into(); + let properties = ContainerProperties::new(container_name.clone(), pk_def); + db_client.create_container(properties, None).await?; + let container_client = db_client.container_client(&container_name).await?; + + Ok((db_name, container_client)) +} + +async fn drop_database(client: &CosmosClient, db_name: &str) { + let db_client = client.database_client(db_name); + let _ = db_client.delete(None).await; +} + +#[derive(Debug, Deserialize, Serialize, PartialEq, Eq, Clone)] +struct Gw20TestItem { + id: String, + pk: String, + value: i64, + label: String, +} + +/// Drives a point CRUD round-trip (create → read → replace → delete) against +/// the live Gateway 2.0 account. +/// +/// TODO: tighten the per-response diagnostics check to assert +/// `TransportKind::Gateway20` once `CosmosDiagnostics` surfaces the +/// transport kind from the driver.
+#[tokio::test] +#[cfg_attr( + not(test_category = "gateway20"), + ignore = "requires test_category 'gateway20' and AZURE_COSMOS_GW20_ENDPOINT/_KEY" +)] +pub async fn gateway20_point_crud_round_trip() -> Result<(), Box<dyn std::error::Error>> { + let Some((endpoint, key)) = live_credentials() else { + return Ok(()); + }; + + let client = build_client(&endpoint, &key, false).await?; + let (db_name, container) = provision_database_and_container(&client).await?; + + let pk_value = format!("pk-{}", azure_core::Uuid::new_v4()); + let item_id = format!("item-{}", azure_core::Uuid::new_v4()); + let mut item = Gw20TestItem { + id: item_id.clone(), + pk: pk_value.clone(), + value: 1, + label: "initial".into(), + }; + + let create_resp = container + .create_item(&pk_value, &item_id, &item, None) + .await?; + assert!(create_resp.diagnostics().activity_id().is_some()); + assert!(create_resp.diagnostics().server_duration_ms().is_some()); + + let read_resp = container + .read_item::<Gw20TestItem>(&pk_value, &item_id, None) + .await?; + assert!(read_resp.diagnostics().activity_id().is_some()); + let read_item: Gw20TestItem = read_resp.into_model()?; + assert_eq!(read_item, item); + + item.value = 2; + item.label = "updated".into(); + let replace_resp = container + .replace_item(&pk_value, &item_id, &item, None) + .await?; + assert!(replace_resp.diagnostics().activity_id().is_some()); + + let delete_resp = container.delete_item(&pk_value, &item_id, None).await?; + assert!(delete_resp.diagnostics().activity_id().is_some()); + + drop_database(&client, &db_name).await; + Ok(()) +} + +/// Runs a SQL query through Gateway 2.0 and asserts the streamed pages all +/// route through the Gateway 2.0 transport. +/// +/// TODO: tighten the per-page diagnostics check to assert +/// `TransportKind::Gateway20` once the SDK exposes the driver transport +/// kind on the page diagnostics.
+#[tokio::test] +#[cfg_attr( + not(test_category = "gateway20"), + ignore = "requires test_category 'gateway20' and AZURE_COSMOS_GW20_ENDPOINT/_KEY" +)] +pub async fn gateway20_query_streams() -> Result<(), Box<dyn std::error::Error>> { + let Some((endpoint, key)) = live_credentials() else { + return Ok(()); + }; + + let client = build_client(&endpoint, &key, false).await?; + let (db_name, container) = provision_database_and_container(&client).await?; + + let pk_value = format!("pk-{}", azure_core::Uuid::new_v4()); + for i in 0..5 { + let item = Gw20TestItem { + id: format!("query-item-{i}"), + pk: pk_value.clone(), + value: i64::from(i), + label: format!("row-{i}"), + }; + let id = item.id.clone(); + container.create_item(&pk_value, &id, &item, None).await?; + } + + let query = Query::from("SELECT * FROM c ORDER BY c.value"); + let mut pages = container + .query_items::<Gw20TestItem>(query, pk_value.clone(), None)? + .into_pages(); + + let mut pages_seen = 0_usize; + let mut items_seen = 0_usize; + while let Some(page) = pages.next().await { + let page = page?; + pages_seen += 1; + assert!(page.diagnostics().activity_id().is_some()); + items_seen += page.items().len(); + } + assert!(pages_seen >= 1, "expected at least one query page"); + assert_eq!(items_seen, 5); + + drop_database(&client, &db_name).await; + Ok(()) +} + +/// Forces multi-page query pagination on Gateway 2.0 by setting +/// `x-ms-max-item-count: 2` and inserting more rows than fit on a single +/// page, then asserts that: +/// +/// * the query produces strictly more than one page, +/// * every row is returned exactly once with no cross-page duplicates, and +/// * pages chain via continuation tokens (the SDK's `Pager` plumbs the +/// response continuation header back as a request continuation header, +/// which the Gateway 2.0 wrap path serializes into RNTBD token `0x0006`).
+/// +/// This is the end-to-end regression test for the request-side continuation +/// propagation bug fix: without it the proxy would always restart from page +/// one and return duplicates instead of advancing. +#[tokio::test] +#[cfg_attr( + not(test_category = "gateway20"), + ignore = "requires test_category 'gateway20' and AZURE_COSMOS_GW20_ENDPOINT/_KEY" +)] +pub async fn gateway20_query_paginates_via_continuation_tokens( +) -> Result<(), Box<dyn std::error::Error>> { + use azure_core::http::headers::{HeaderName, HeaderValue}; + use azure_data_cosmos::options::{OperationOptions, QueryOptions}; + use std::collections::{HashMap, HashSet}; + + let Some((endpoint, key)) = live_credentials() else { + return Ok(()); + }; + + let client = build_client(&endpoint, &key, false).await?; + let (db_name, container) = provision_database_and_container(&client).await?; + + let pk_value = format!("pk-{}", azure_core::Uuid::new_v4()); + let total_items: usize = 7; + for i in 0..total_items { + let item = Gw20TestItem { + id: format!("page-item-{i}"), + pk: pk_value.clone(), + value: i as i64, + label: format!("row-{i}"), + }; + let id = item.id.clone(); + container.create_item(&pk_value, &id, &item, None).await?; + } + + let mut custom_headers: HashMap<HeaderName, HeaderValue> = HashMap::new(); + custom_headers.insert( + HeaderName::from_static("x-ms-max-item-count"), + HeaderValue::from_static("2"), + ); + let query_options = QueryOptions::default() + .with_operation_options(OperationOptions::default().with_custom_headers(custom_headers)); + + let query = Query::from("SELECT * FROM c ORDER BY c.value"); + let mut pages = container + .query_items::<Gw20TestItem>(query, pk_value.clone(), Some(query_options))?
+ .into_pages(); + + let mut pages_seen = 0_usize; + let mut ids_seen: HashSet<String> = HashSet::new(); + while let Some(page) = pages.next().await { + let page = page?; + pages_seen += 1; + assert!( + page.diagnostics().activity_id().is_some(), + "every Gateway 2.0 page must surface an activity-id", + ); + for item in page.items() { + assert!( + ids_seen.insert(item.id.clone()), + "item {} returned twice — pagination did not advance (continuation token not propagated)", + item.id, + ); + } + } + + assert!( + pages_seen > 1, + "expected continuation-driven pagination to produce more than one page (got {pages_seen})", + ); + assert_eq!( + ids_seen.len(), + total_items, + "expected all {total_items} inserted rows; saw {} unique ids across {pages_seen} pages", + ids_seen.len(), + ); + + drop_database(&client, &db_name).await; + Ok(()) +} + +/// Executes a transactional batch (two creates plus an upsert in a single +/// logical partition) through Gateway 2.0 and asserts every sub-operation +/// succeeds. +/// +/// TODO: tighten the diagnostics check to assert `TransportKind::Gateway20` +/// once the SDK surfaces the driver transport kind on batch diagnostics. +#[tokio::test] +#[cfg_attr( + not(test_category = "gateway20"), + ignore = "requires test_category 'gateway20' and AZURE_COSMOS_GW20_ENDPOINT/_KEY" +)] +pub async fn gateway20_transactional_batch() -> Result<(), Box<dyn std::error::Error>> { + let Some((endpoint, key)) = live_credentials() else { + return Ok(()); + }; + + let client = build_client(&endpoint, &key, false).await?; + let (db_name, container) = provision_database_and_container(&client).await?; + + let pk_value = format!("pk-{}", azure_core::Uuid::new_v4()); + let item_a = Gw20TestItem { + id: "batch-a".into(), + pk: pk_value.clone(), + value: 10, + label: "a".into(), + }; + let item_b = Gw20TestItem { + id: "batch-b".into(), + pk: pk_value.clone(), + value: 20, + label: "b".into(), + }; + let upsert = Gw20TestItem { + id: "batch-c".into(), + pk: pk_value.clone(), + value: 30, + label: "c".into(), + }; + + let batch = TransactionalBatch::new(&pk_value) + .create_item(&item_a)? + .create_item(&item_b)?
+ .upsert_item(&upsert, None)?; + + let response = container.execute_transactional_batch(batch, None).await?; + let body = response.into_model()?; + let codes: Vec<u16> = body.results().iter().map(|r| r.status_code()).collect(); + assert_eq!(codes, vec![201, 201, 201]); + + drop_database(&client, &db_name).await; + Ok(()) +} + +/// Drives a `LatestVersion` change feed iterator through Gateway 2.0. +/// +/// TODO: implement once the SDK exposes a public change-feed API on +/// `ContainerClient`. Only routing-layer change-feed plumbing exists today +/// (`execute_partition_key_range_read_change_feed`); there is no public +/// `ContainerClient::change_feed` entry point yet, so the test cannot +/// exercise the SDK surface end-to-end. Tracking item: SDK change-feed +/// public API. +#[tokio::test] +#[cfg_attr( + not(test_category = "gateway20"), + ignore = "requires test_category 'gateway20' and AZURE_COSMOS_GW20_ENDPOINT/_KEY" +)] +pub async fn gateway20_change_feed_latest_version() { + let Some((_endpoint, _key)) = live_credentials() else { + return; + }; + // Intentionally empty — see the test docs above for why. +} + +/// Verifies that diagnostics are populated for SDK-issued requests routed +/// through Gateway 2.0. +/// +/// TODO: extend this test to assert `TransportKind::Gateway20` once +/// `CosmosDiagnostics` surfaces the driver transport kind. Today the SDK +/// `CosmosDiagnostics` only carries `activity_id` and `server_duration_ms`, +/// so the strongest behavioural assertion we can make is that those fields +/// are populated when the request was routed through the Gateway 2.0 +/// pipeline.
+#[tokio::test] +#[cfg_attr( + not(test_category = "gateway20"), + ignore = "requires test_category 'gateway20' and AZURE_COSMOS_GW20_ENDPOINT/_KEY" +)] +pub async fn gateway20_diagnostics_validation() -> Result<(), Box<dyn std::error::Error>> { + let Some((endpoint, key)) = live_credentials() else { + return Ok(()); + }; + + let client = build_client(&endpoint, &key, false).await?; + let (db_name, container) = provision_database_and_container(&client).await?; + + let pk_value = format!("pk-{}", azure_core::Uuid::new_v4()); + let item = Gw20TestItem { + id: "diag-item".into(), + pk: pk_value.clone(), + value: 99, + label: "diag".into(), + }; + container + .create_item(&pk_value, "diag-item", &item, None) + .await?; + + let read_resp = container + .read_item::<Gw20TestItem>(&pk_value, "diag-item", None) + .await?; + let diagnostics = read_resp.diagnostics(); + assert!( + diagnostics.activity_id().is_some(), + "expected activity_id to be populated for a Gateway 2.0 request" + ); + assert!( + diagnostics.server_duration_ms().is_some(), + "expected server_duration_ms to be populated for a Gateway 2.0 request" + ); + + drop_database(&client, &db_name).await; + Ok(()) +} + +/// Verifies the operator override at the SDK boundary: when the operator +/// disables Gateway 2.0 via [`CosmosClientBuilder::with_gateway20_disabled`], +/// every request must route through the standard gateway even though the +/// account advertises a Gateway 2.0 endpoint. +/// +/// TODO: tighten the assertion to inspect `TransportKind::StandardGateway` +/// in the diagnostics once the SDK exposes the driver transport kind.
+/// +/// [`CosmosClientBuilder::with_gateway20_disabled`]: azure_data_cosmos::CosmosClientBuilder::with_gateway20_disabled +#[tokio::test] +#[cfg_attr( + not(test_category = "gateway20"), + ignore = "requires test_category 'gateway20' and AZURE_COSMOS_GW20_ENDPOINT/_KEY" +)] +pub async fn gateway20_operator_override_at_sdk_boundary() -> Result<(), Box<dyn std::error::Error>> +{ + let Some((endpoint, key)) = live_credentials() else { + return Ok(()); + }; + + let client = build_client(&endpoint, &key, true).await?; + let (db_name, container) = provision_database_and_container(&client).await?; + + let pk_value = format!("pk-{}", azure_core::Uuid::new_v4()); + let item = Gw20TestItem { + id: "override-item".into(), + pk: pk_value.clone(), + value: 7, + label: "override".into(), + }; + container + .create_item(&pk_value, "override-item", &item, None) + .await?; + + let read_resp = container + .read_item::<Gw20TestItem>(&pk_value, "override-item", None) + .await?; + let diagnostics = read_resp.diagnostics(); + assert!(diagnostics.activity_id().is_some()); + + drop_database(&client, &db_name).await; + Ok(()) +} + +/// Provisions a fresh database + 3-component HPK container and returns the +/// db name (for cleanup) and a container client. Mirrors +/// [`provision_database_and_container`] but uses +/// `(/tenantId, /userId, /sessionId)` as the partition key paths so the +/// container exercises hierarchical partitioning end-to-end.
+async fn provision_database_and_hpk_container( + client: &CosmosClient, +) -> Result<(String, azure_data_cosmos::clients::ContainerClient), Box<dyn std::error::Error>> { + let unique = azure_core::Uuid::new_v4(); + let db_name = format!("gw20-test-db-{unique}"); + let container_name = format!("gw20-test-hpk-container-{unique}"); + + client.create_database(&db_name, None).await?; + let db_client = client.database_client(&db_name); + + let pk_def = PartitionKeyDefinition::from(("/tenantId", "/userId", "/sessionId")); + let properties = ContainerProperties::new(container_name.clone(), pk_def); + db_client.create_container(properties, None).await?; + let container_client = db_client.container_client(&container_name).await?; + + Ok((db_name, container_client)) +} + +#[derive(Debug, Deserialize, Serialize, PartialEq, Eq, Clone)] +struct Gw20HpkItem { + id: String, + #[serde(rename = "tenantId")] + tenant_id: String, + #[serde(rename = "userId")] + user_id: String, + #[serde(rename = "sessionId")] + session_id: String, + value: i64, +} + +/// Round-trip exercises Gateway 2.0 against a 3-component hierarchical +/// partition key container, asserting both the **full PK** point-op path +/// and the **partial PK** range-dispatch path (`x-ms-thinclient-range-min` +/// / `-max`) discussed in the Gateway 2.0 spec test matrix +/// ("HPK + Gateway 2.0: full vs partial PK"). +/// +/// 1. Inserts items spread across two tenants × two users. +/// 2. Reads each item back via its full 3-component PK (point op → EPK token). +/// 3. Queries with a **1-component prefix** (`tenantId` only) and asserts +/// the items for that tenant come back across however many pages the +/// proxy fans out into. +/// +/// The point-vs-range header emission is asserted at unit level in +/// `gateway20_dispatch::tests`; this E2E test guards the SDK-public surface +/// against regressions where partial-PK queries silently degrade to +/// single-partition or fail.
+/// +/// TODO: tighten the diagnostics check to assert `TransportKind::Gateway20` +/// once the SDK surfaces the driver transport kind. +#[tokio::test] +#[cfg_attr( + not(test_category = "gateway20"), + ignore = "requires test_category 'gateway20' and AZURE_COSMOS_GW20_ENDPOINT/_KEY" +)] +pub async fn gateway20_hpk_full_and_partial_partition_key_round_trip( +) -> Result<(), Box<dyn std::error::Error>> { + use azure_data_cosmos::{PartitionKey, PartitionKeyValue}; + + let Some((endpoint, key)) = live_credentials() else { + return Ok(()); + }; + + let client = build_client(&endpoint, &key, false).await?; + let (db_name, container) = provision_database_and_hpk_container(&client).await?; + + let target_tenant = format!("tenant-{}", azure_core::Uuid::new_v4()); + let other_tenant = format!("tenant-{}", azure_core::Uuid::new_v4()); + + // Two users × two sessions per tenant => 4 items per tenant. + let mut expected_target_ids = Vec::new(); + for tenant in [target_tenant.as_str(), other_tenant.as_str()] { + for user_idx in 0..2 { + for session_idx in 0..2 { + let user_id = format!("user-{user_idx}"); + let session_id = format!("session-{session_idx}"); + let id = format!("{tenant}-{user_id}-{session_id}"); + if tenant == target_tenant { + expected_target_ids.push(id.clone()); + } + let item = Gw20HpkItem { + id: id.clone(), + tenant_id: tenant.to_string(), + user_id: user_id.clone(), + session_id: session_id.clone(), + value: i64::from(user_idx * 10 + session_idx), + }; + // PartitionKey tuple impls require owned types (the underlying + // `PartitionKeyValue: From<&'static str>` impl is the only + // borrow-friendly one) — clone strings into the tuple. + let pk = PartitionKey::from((tenant.to_string(), user_id, session_id)); + container.create_item(pk, &id, &item, None).await?; + } + } + } + + // Full HPK point read (3-of-3 components → EPK token path).
+ let full_pk = PartitionKey::from(( + target_tenant.clone(), + "user-0".to_string(), + "session-0".to_string(), + )); + let full_id = format!("{target_tenant}-user-0-session-0"); + let read_resp = container + .read_item::<Gw20HpkItem>(full_pk, &full_id, None) + .await?; + let item: Gw20HpkItem = read_resp.into_model()?; + assert_eq!(item.id, full_id); + assert_eq!(item.tenant_id, target_tenant); + + // Partial HPK query (1-of-3 components → range header path). + // PartitionKey only has tuple From-impls for 2 and 3 components; for a + // single-component prefix, construct it from a Vec<PartitionKeyValue> so + // the dispatcher sees a 1-component value against a 3-path container. + let partial_pk = PartitionKey::from(vec![PartitionKeyValue::from(target_tenant.clone())]); + let query = Query::from("SELECT * FROM c"); + let mut pages = container + .query_items::<Gw20HpkItem>(query, partial_pk, None)? + .into_pages(); + + let mut returned_ids: Vec<String> = Vec::new(); + let mut pages_seen = 0_usize; + while let Some(page) = pages.next().await { + let page = page?; + pages_seen += 1; + assert!(page.diagnostics().activity_id().is_some()); + for it in page.items() { + assert_eq!( + it.tenant_id, target_tenant, + "partial-PK query must not bleed across tenants" + ); + returned_ids.push(it.id.clone()); + } + } + assert!(pages_seen >= 1, "expected at least one query page"); + expected_target_ids.sort(); + returned_ids.sort(); + assert_eq!(returned_ids, expected_target_ids); + + drop_database(&client, &db_name).await; + Ok(()) +} diff --git a/sdk/cosmos/azure_data_cosmos/tests/emulator_tests/mod.rs b/sdk/cosmos/azure_data_cosmos/tests/emulator_tests/mod.rs index 02b9779b1b3..bc8689e0c44 100644 --- a/sdk/cosmos/azure_data_cosmos/tests/emulator_tests/mod.rs +++ b/sdk/cosmos/azure_data_cosmos/tests/emulator_tests/mod.rs @@ -10,6 +10,7 @@ mod cosmos_items; mod cosmos_offers; mod cosmos_proxy; mod cosmos_query; +mod gateway20_e2e; #[path = "../framework/mod.rs"] mod framework; diff --git
a/sdk/cosmos/azure_data_cosmos/tests/framework/mock_account.rs b/sdk/cosmos/azure_data_cosmos/tests/framework/mock_account.rs index 49fdb7dbad2..90e172ff713 100644 --- a/sdk/cosmos/azure_data_cosmos/tests/framework/mock_account.rs +++ b/sdk/cosmos/azure_data_cosmos/tests/framework/mock_account.rs @@ -4,8 +4,8 @@ //! Helpers for building mock `GetDatabaseAccount` responses in fault injection tests. // cSpell: disable -use azure_core::http::{headers::Headers, StatusCode}; -use azure_data_cosmos::fault_injection::CustomResponse; +use azure_core::http::StatusCode; +use azure_data_cosmos::fault_injection::{CustomResponse, CustomResponseBuilder}; use azure_data_cosmos::regions::Region; /// Builds a [`CustomResponse`] containing a valid `AccountProperties` JSON payload @@ -48,11 +48,9 @@ pub fn mock_database_account_response_for_account( multi_write: bool, ) -> CustomResponse { let body = mock_database_account_json(account_name, writable, readable, multi_write); - CustomResponse { - status_code: StatusCode::Ok, - headers: Headers::new(), - body: body.into_bytes(), - } + CustomResponseBuilder::new(StatusCode::Ok) + .with_body(body.into_bytes()) + .build() } /// Builds a valid `AccountProperties` JSON string with the specified regions. 
@@ -119,7 +117,7 @@ mod tests { ); let value: serde_json::Value = - serde_json::from_slice(&response.body).expect("should deserialize"); + serde_json::from_slice(response.body()).expect("should deserialize"); let writable = value["writableLocations"].as_array().unwrap(); let readable = value["readableLocations"].as_array().unwrap(); @@ -139,7 +137,7 @@ mod tests { ); let value: serde_json::Value = - serde_json::from_slice(&response.body).expect("should deserialize"); + serde_json::from_slice(response.body()).expect("should deserialize"); assert!(value["enableMultipleWriteLocations"].as_bool().unwrap()); } diff --git a/sdk/cosmos/azure_data_cosmos_driver/CHANGELOG.md b/sdk/cosmos/azure_data_cosmos_driver/CHANGELOG.md index 3a1550f16f5..d0fe9195958 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/CHANGELOG.md +++ b/sdk/cosmos/azure_data_cosmos_driver/CHANGELOG.md @@ -4,6 +4,8 @@ ### Features Added +- Added Gateway 2.0 transport support, enabled by default. The new transport routes data-plane requests through a regional Gateway 2.0 proxy that forwards RNTBD-over-HTTP/2 to the backend. Set `ConnectionPoolOptionsBuilder::with_gateway20_disabled(true)` to fall back to the direct gateway transport. Note that Gateway 2.0 is **not currently covered by latency SLAs** and may impose higher per-request latency. ([#4319](https://github.com/Azure/azure-sdk-for-rust/pull/4319)) + ### Breaking Changes ### Bugs Fixed diff --git a/sdk/cosmos/azure_data_cosmos_driver/build.rs b/sdk/cosmos/azure_data_cosmos_driver/build.rs index 03429fb6dca..a631bbe5814 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/build.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/build.rs @@ -6,5 +6,7 @@ // unknown cfg names are warned/denied unless explicitly declared via check-cfg. fn main() { // Allow `#[cfg_attr(not(test_category = "..."), ignore)]` in `tests/*.rs`. 
- println!("cargo:rustc-check-cfg=cfg(test_category, values(\"emulator\", \"multi_write\"))"); + println!( + "cargo:rustc-check-cfg=cfg(test_category, values(\"emulator\", \"multi_write\", \"gateway20\"))" + ); } diff --git a/sdk/cosmos/azure_data_cosmos_driver/docs/GATEWAY_20_SPEC.md b/sdk/cosmos/azure_data_cosmos_driver/docs/GATEWAY_20_SPEC.md index dc09653487d..149285dfa8c 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/docs/GATEWAY_20_SPEC.md +++ b/sdk/cosmos/azure_data_cosmos_driver/docs/GATEWAY_20_SPEC.md @@ -90,10 +90,10 @@ Gateway 2.0 routing is decided **once per logical operation** (a point operation ```text gateway20_suppressed = options.gateway20_disabled - || !account.has_thin_client_endpoints() + || !account.has_gateway20_endpoints() ``` -When `gateway20_suppressed` is `false` (the default whenever the account advertises Gateway 2.0 endpoints and the operator has not flipped the override), the request routes through Gateway 2.0. When it is `true`, the request falls through to Gateway V1. The account-side check (`has_thin_client_endpoints()`) reads the cached account metadata. The client-side check (`gateway20_disabled`) is the only public toggle. +When `gateway20_suppressed` is `false` (the default whenever the account advertises Gateway 2.0 endpoints and the operator has not flipped the override), the request routes through Gateway 2.0. When it is `true`, the request falls through to Gateway V1. The account-side check (`has_gateway20_endpoints()`) reads the cached account metadata. The client-side check (`gateway20_disabled`) is the only public toggle. ### 3.2 Operator override: `CosmosClientOptions::gateway20_disabled` @@ -205,6 +205,15 @@ The Rust deserializer **must** treat the RNTBD response metadata-token stream as - **Unknown token type IDs MUST be silently skipped** (consume `length` bytes and continue) — the deserializer must NOT panic, return an error, or fail the response, and must NOT log per-token (silent skip is the contract). 
The proxy is free to add new metadata tokens at any time and the driver must remain forward-compatible across proxy upgrades that ship before the corresponding Rust release. This silent-tolerance behavior is the *implementation* of the `IgnoreUnknownRntbdTokens` capability bit advertised over the `x-ms-cosmos-sdk-supportedcapabilities` header (see "SDK-supported-capabilities advertisement" below) — the proxy/backend assumes the SDK will not surface or warn on unknown tokens, so per-token logging is unnecessary noise. - **Inverse contract on the request side**: the request serializer drops headers that appear in `thinClientProxyExcludedSet` (see §"RNTBD Request Wire Format" Notes column). That set enumerates headers the proxy does not understand on the inbound RNTBD frame; emitting them would be either ignored or rejected. +##### Continuation-token format (request and response) + +Continuation tokens are **opaque server-issued strings** in both directions; the SDK never parses, validates, or rewrites them. The wire format is a length-prefixed UTF-8 string token mirroring Java's RNTBD encoding: + +- **Request side** — `RntbdRequestToken::ContinuationToken` (ID `0x0006`, `TokenType::String`). When the inbound HTTP request carries `x-ms-continuation`, the wrap path serializes the value verbatim into the RNTBD metadata stream and **strips** the header from the outer HTTP request (the outer body is the RNTBD frame; metadata never duplicates onto outer headers). Empty values are passed through as zero-length string tokens — the wrap path does not infer intent from emptiness, matching the unwrap side and the .NET/Java behavior. +- **Response side** — `RntbdResponseToken::ContinuationToken` (ID `0x0003`, `TokenType::String`). The unwrap path forwards the token value verbatim into the synthetic HTTP response's `x-ms-continuation` header. 
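The request-side contract above can be sketched as follows. This is an illustrative stand-in, not the driver's actual API: `take_continuation_for_rntbd` is a hypothetical helper and a plain `HashMap` substitutes for the real header types; what it pins down is the behavior — the header value moves verbatim into the token stream, is stripped from the outer request, and an empty value is still a real (zero-length) token:

```rust
use std::collections::HashMap;

/// Hypothetical helper illustrating the wrap-path contract: strip
/// `x-ms-continuation` from the outer HTTP request and return the verbatim
/// value to serialize as RNTBD string token 0x0006. An empty header value is
/// still emitted as a zero-length string token, never dropped.
fn take_continuation_for_rntbd(outer_headers: &mut HashMap<String, String>) -> Option<String> {
    // `remove` both strips the header from the outer request and yields the
    // byte-for-byte value for the RNTBD metadata stream.
    outer_headers.remove("x-ms-continuation")
}

fn main() {
    let mut headers = HashMap::new();
    headers.insert("x-ms-continuation".to_string(), "server-cursor".to_string());

    // The token value is the header value, untouched...
    assert_eq!(
        take_continuation_for_rntbd(&mut headers).as_deref(),
        Some("server-cursor")
    );
    // ...and the header no longer appears on the outer request (metadata
    // never duplicates onto outer headers).
    assert!(!headers.contains_key("x-ms-continuation"));

    // Empty values pass through as zero-length tokens, not as "absent".
    headers.insert("x-ms-continuation".to_string(), String::new());
    assert_eq!(take_continuation_for_rntbd(&mut headers).as_deref(), Some(""));
}
```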
+ +Identical semantics to .NET (`ThinClientStoreClient.cs` / `ThinClientTransportSerializer.cs`, which contain no continuation-specific logic and rely on the standard gateway path) and Java (`RntbdRequestHeader.ContinuationToken` is *not* in `thinClientProxyExcludedSet`, so it traverses the same encode/decode path as standard direct-mode RNTBD). There is no Gateway-2.0-specific token format, base64 wrapper, or version prefix; pagination cursors round-trip byte-for-byte. + Phase 6's "RNTBD unknown-token tolerance" unit test pins this behavior: a hand-crafted response frame containing a synthetic unrecognized token ID must round-trip without error and surface every recognized token correctly. #### SDK-supported-capabilities advertisement @@ -215,6 +224,18 @@ Phase 1 must change the emitted value to the bitmask `(PartitionMerge | IgnoreUn The `IgnoreUnknownRntbdTokens` bit is the contract that backs the silent-skip behavior in "Metadata token filtering" above: the proxy/backend uses this advertisement to decide whether it is safe to add new RNTBD tokens without coordinating with this SDK release. Advertising the bit while *also* failing or warning on unknown tokens would be a contract violation; advertising `"0"` while silently skipping unknown tokens is "merely conservative" but causes the proxy to assume zero forward-compat tolerance — both are wrong. Phase 1 must reconcile both ends. +##### Capability bit composition (Rust = `9`, Java = `11`) + +The bitmask the Rust driver advertises is **`9`** (`PartitionMerge | IgnoreUnknownRntbdTokens`). Pinned in `azure_data_cosmos_driver/src/driver/transport/cosmos_headers.rs:16-25` with a `const _: () = assert!(SUPPORTED_CAPABILITIES_BITS == 9);` invariant. 
The bits are sourced from .NET `SDKSupportedCapabilities.cs` and the C++ proxy enum: + +| Bit | Decimal | Capability | Rust advertises | Java advertises | Notes | +| ---- | ------- | ---------------------- | --------------- | --------------- | ------------------------------------------------------------------------------------------------------------------------------------ | +| 0 | 1 | `PartitionMerge` | yes | yes | Forward-compat with merged partition-key ranges; required for Gateway 2.0 because the proxy may surface merged ranges in routing. | +| 1 | 2 | (Java-only capability; name per Java `SDKSupportedCapabilities`) | **no** | yes | Java opts in to an additional capability the Rust driver does not yet consume. Unilaterally advertising it without honoring the corresponding behavior could cause mis-framing or unexpected proxy behavior. Verify the exact capability name against Java/.NET source before adding. Track in a follow-up if/when the driver grows the corresponding support. | +| 3 | 8 | `IgnoreUnknownRntbdTokens` | yes | yes | Forward-compat with new RNTBD response tokens added by future proxy/backend versions; backed by the silent-skip behavior in "Metadata token filtering" above. | + +Total: Rust `1 | 8 = 9`; Java `1 | 2 | 8 = 11`. The two-bit gap is intentional and conservative — the Rust driver only advertises capabilities it actually implements end-to-end. Adding bit 1 (or any future bit) requires implementing the corresponding behavior first, then incrementing the constant in `cosmos_headers.rs` and re-pinning `Phase 6`'s header-value test. + Phase 6 test coverage: assert the header value emitted on Gateway 2.0 (and standard Gateway) requests is the expected bitmask string, not `"0"`. 
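The bit arithmetic above can be written out as a minimal sketch. The constant names here are illustrative (mirroring the table, not necessarily the identifiers in `cosmos_headers.rs`); the compile-time pins reproduce the spirit of the `const _: () = assert!(...)` invariant described above:

```rust
// Capability bits from the table above; names are illustrative.
const PARTITION_MERGE: u64 = 1 << 0; // bit 0, decimal 1
const JAVA_ONLY_CAPABILITY: u64 = 1 << 1; // bit 1, decimal 2: Java advertises, Rust does not
const IGNORE_UNKNOWN_RNTBD_TOKENS: u64 = 1 << 3; // bit 3, decimal 8

// Rust advertises only what it implements end-to-end: 1 | 8 = 9.
const RUST_SUPPORTED_CAPABILITIES: u64 = PARTITION_MERGE | IGNORE_UNKNOWN_RNTBD_TOKENS;
// Java additionally sets bit 1: 1 | 2 | 8 = 11.
const JAVA_SUPPORTED_CAPABILITIES: u64 =
    PARTITION_MERGE | JAVA_ONLY_CAPABILITY | IGNORE_UNKNOWN_RNTBD_TOKENS;

// Compile-time pins, in the spirit of the invariant described above.
const _: () = assert!(RUST_SUPPORTED_CAPABILITIES == 9);
const _: () = assert!(JAVA_SUPPORTED_CAPABILITIES == 11);

fn main() {
    // The header carries the decimal bitmask rendered as a string, not "0".
    assert_eq!(RUST_SUPPORTED_CAPABILITIES.to_string(), "9");
    println!("x-ms-cosmos-sdk-supportedcapabilities: {RUST_SUPPORTED_CAPABILITIES}");
}
```

Adding a future bit means implementing the behavior first, then updating the constant so the compile-time assertion (and Phase 6's header-value test) fails loudly until both ends agree.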
#### RNTBD Request Wire Format @@ -231,7 +252,7 @@ RntbdUUID.encode(activityId, out); // two longs | Offset | Size | Field | Encoding | Notes | | --- | --- | --- | --- | --- | | 0 | 4 | Total message length | uint32 LE | **Inclusive** of the 4 length bytes themselves (matches Java `writeIntLE` semantics). | -| 4 | 2 | Resource type | uint16 LE | `writeShortLE(resourceType.id())` — narrower than direct-mode RNTBD's uint32 because thin-client IDs fit in 16 bits. | +| 4 | 2 | Resource type | uint16 LE | `writeShortLE(resourceType.id())` — narrower than direct-mode RNTBD's uint32 because Gateway 2.0 IDs fit in 16 bits. | | 6 | 2 | Operation type | uint16 LE | `writeShortLE(operationType.id())` — same rationale. | | 8 | 16 | Activity ID | UUID, two uint64 LE | Java writes `(mostSignificantBits, leastSignificantBits)` as two little-endian `long`s — **this is not RFC 4122 byte order**. Worked example for UUID `0a1b2c3d-4e5f-6789-abcd-ef0123456789`: `mostSignificantBits = 0x0a1b2c3d_4e5f_6789` → LE bytes `89 67 5f 4e 3d 2c 1b 0a`; `leastSignificantBits = 0xabcd_ef01_2345_6789` → LE bytes `89 67 45 23 01 ef cd ab`. The on-the-wire 16-byte sequence is the MSB bytes followed by the LSB bytes. | | 24 | var | Metadata tokens | Token stream | Filtered by `thinClientProxyExcludedSet` (see §Phase 2 header naming). | @@ -310,7 +331,7 @@ These are wire-level HTTP/2 request headers on the outer POST to the proxy. 
They | `x-ms-documentdb-partitionkey` | existing `PARTITION_KEY` constant (SDK) | JSON-encoded partition-key value | Point ops AND single-logical-partition query ops, alongside `x-ms-effective-partition-key` | | `x-ms-thinclient-range-min` | **NEW** — `GATEWAY20_RANGE_MIN` (driver) | Lower bound of EPK range | Feed / cross-partition ops only | | `x-ms-thinclient-range-max` | **NEW** — `GATEWAY20_RANGE_MAX` (driver) | Upper bound of EPK range | Feed / cross-partition ops only | -| `x-ms-cosmos-use-thinclient` | **NEW** — `GATEWAY20_USE_THINCLIENT` (driver) | Instructs account-metadata response to advertise thin-client endpoints | Account metadata fetches only | +| `x-ms-cosmos-use-thinclient` | **NEW** — `GATEWAY20_USE_THINCLIENT` (driver) | Instructs account-metadata response to advertise Gateway 2.0 endpoints | Account metadata fetches only | > Wire-header strings (`x-ms-thinclient-*`) are server-defined and unchanged; the Rust-side identifiers use the `GATEWAY20_*` prefix. @@ -412,7 +433,7 @@ EDIT sdk/cosmos/azure_data_cosmos/src/... 
— Replace SDK-side get_hashe - **Verify** account metadata cache parses `thinClientReadableLocations` / `thinClientWritableLocations` into `CosmosEndpoint::gateway20_url` - **Confirm** `build_account_endpoint_state()` constructs `CosmosEndpoint::regional_with_gateway20()` correctly in multi-region accounts (existing tests at `routing_systems.rs:218–289` already cover this) -- **Verify** `AccountProperties::has_thin_client_endpoints()` is used as the gating signal per §3.1 +- **Verify** `AccountProperties::has_gateway20_endpoints()` is used as the gating signal per §3.1 - **Add** `x-ms-cosmos-use-thinclient` request header on account metadata fetches (new code) - **Test** endpoint discovery with live account that has gateway 2.0 enabled (handled by Phase 6 live pipeline) @@ -434,7 +455,7 @@ Account metadata response includes: #### Files Changed ``` -EDIT src/driver/cache/account_metadata_cache.rs — Verify thin client endpoint parsing (audit only) +EDIT src/driver/cache/account_metadata_cache.rs — Verify Gateway 2.0 endpoint parsing (audit only) EDIT src/driver/transport/cosmos_headers.rs — Add x-ms-cosmos-use-thinclient header (NEW) TEST src/driver/routing/routing_systems.rs — Add tests for read/write pairing edge cases ``` @@ -540,7 +561,8 @@ A **new dedicated CI pipeline** is required for gateway 2.0 live tests. Gateway | Action | File | Purpose | | --- | --- | --- | -| NEW | `sdk/cosmos/ci-gateway20.yml` | Gateway 2.0 live tests pipeline definition (uses pre-provisioned account) | +| EDIT | `sdk/cosmos/ci.yml` | Add a second `LiveTestMatrixConfigs` entry (`Cosmos_gateway20_live_test`) that points at `live-gateway20-matrix.json`, plus an `EnvVars` block that injects `AZURE_COSMOS_GW20_ENDPOINT` / `AZURE_COSMOS_GW20_KEY` from the `azure-sdk-tests-cosmos` service connection. Mirrors Java's `sdk/cosmos/tests.yml` thin-client setup. 
| +| NEW | `sdk/cosmos/live-gateway20-matrix.json` | Gateway 2.0 live test matrix (single-region + multi-region; `testCategory` = `gateway20` / `gateway20_multi_region`). The pre-provisioned account is supplied via the env vars above; the matrix's `ArmTemplateParameters` block is preserved so the deploy step still runs even though the per-run account is unused. | | EDIT | `sdk/cosmos/live-platform-matrix.json` | Add gateway 2.0 test matrix entry | #### Test Coverage Matrix diff --git a/sdk/cosmos/azure_data_cosmos_driver/docs/TRANSPORT_PIPELINE_SPEC.md b/sdk/cosmos/azure_data_cosmos_driver/docs/TRANSPORT_PIPELINE_SPEC.md index 09361ad2736..29771225e6d 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/docs/TRANSPORT_PIPELINE_SPEC.md +++ b/sdk/cosmos/azure_data_cosmos_driver/docs/TRANSPORT_PIPELINE_SPEC.md @@ -2242,11 +2242,11 @@ HTTP/2 just uses a single `Arc` like HTTP/1.1. endpoints are detected and used. No sharding yet — stream limit may be hit under high load. > **Note on ALPN probing (§6.0):** The initial Step 5 implementation uses configuration flags -> (`is_http2_allowed`, `is_gateway20_allowed`) and `AccountProperties` metadata +> (`is_http2_allowed`, `gateway20_disabled`) and `AccountProperties` metadata > (`thinClient*Locations`) to determine the transport strategy, rather than runtime ALPN > negotiation against the gateway. This is sufficient because: > (1) reqwest with `http2` feature already performs ALPN automatically for `Http2Preferred`, -> (2) Gateway 2.0 is definitively identified by the presence of thin-client locations in +> (2) Gateway 2.0 is definitively identified by the presence of Gateway 2.0 locations in > account metadata, and (3) `http2_prior_knowledge()` for `Http2Only` skips ALPN entirely > (h2 is guaranteed). Runtime probing may be revisited if a use case arises where the > configuration-based approach is insufficient. @@ -2317,7 +2317,7 @@ Cut over all remaining operations and remove the old pipeline code. 
| 10.7 | Move fault injection tests to driver-level APIs | `tests/` |
| 10.8 | Full integration test pass | `tests/` |

-**What works after Step 10**: `azure_data_cosmos` is a thin client layer that builds
+**What works after Step 10**: `azure_data_cosmos` is a thin SDK wrapper layer that builds
 `CosmosOperation` values and delegates all execution to the driver. Duplicate pipeline,
 retry, and routing code is removed.

diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/constants.rs b/sdk/cosmos/azure_data_cosmos_driver/src/constants.rs
new file mode 100644
index 00000000000..de5f39c5ad5
--- /dev/null
+++ b/sdk/cosmos/azure_data_cosmos_driver/src/constants.rs
@@ -0,0 +1,108 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// Don't spell-check header names (which should start with 'x-').
+// cSpell:disable
+
+//! Driver-level Cosmos DB constants.
+//!
+//! This module owns the canonical wire-name strings for the Gateway 2.0
+//! HTTP/2 outer headers. The wire strings retain the historical
+//! `x-ms-thinclient-*` form because these header names are server-defined;
+//! only the Rust identifiers follow the `GATEWAY20_*` naming convention.
+
+use azure_core::http::headers::HeaderName;
+
+/// Gateway 2.0 proxy operation-type header.
+///
+/// Contains the numeric operation type on every Gateway 2.0 request.
+pub const GATEWAY20_OPERATION_TYPE: HeaderName =
+    HeaderName::from_static("x-ms-thinclient-proxy-operation-type");
+
+/// Gateway 2.0 proxy resource-type header.
+///
+/// Contains the numeric resource type on every Gateway 2.0 request.
+pub const GATEWAY20_RESOURCE_TYPE: HeaderName =
+    HeaderName::from_static("x-ms-thinclient-proxy-resource-type");
+
+/// Effective Partition Key header.
+///
+/// Sent for point Document operations only.
+pub const EFFECTIVE_PARTITION_KEY: HeaderName =
+    HeaderName::from_static("x-ms-effective-partition-key");
+
+/// Lower bound of the EPK range.
+/// +/// Sent for feed and cross-partition operations only. +pub const GATEWAY20_RANGE_MIN: HeaderName = HeaderName::from_static("x-ms-thinclient-range-min"); + +/// Upper bound of the EPK range. +/// +/// Sent for feed and cross-partition operations only. +pub const GATEWAY20_RANGE_MAX: HeaderName = HeaderName::from_static("x-ms-thinclient-range-max"); + +/// Account-metadata fetch hint. +/// +/// Instructs the response to advertise Gateway 2.0 endpoints. +pub const GATEWAY20_USE_THINCLIENT: HeaderName = + HeaderName::from_static("x-ms-cosmos-use-thinclient"); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn constants_match_expected_wire_strings() { + let cases = [ + ( + GATEWAY20_OPERATION_TYPE, + HeaderName::from_static("x-ms-thinclient-proxy-operation-type"), + ), + ( + GATEWAY20_RESOURCE_TYPE, + HeaderName::from_static("x-ms-thinclient-proxy-resource-type"), + ), + ( + EFFECTIVE_PARTITION_KEY, + HeaderName::from_static("x-ms-effective-partition-key"), + ), + ( + GATEWAY20_RANGE_MIN, + HeaderName::from_static("x-ms-thinclient-range-min"), + ), + ( + GATEWAY20_RANGE_MAX, + HeaderName::from_static("x-ms-thinclient-range-max"), + ), + ( + GATEWAY20_USE_THINCLIENT, + HeaderName::from_static("x-ms-cosmos-use-thinclient"), + ), + ]; + + for (actual, expected) in cases { + assert_eq!(actual, expected); + } + } + + #[test] + fn constants_have_distinct_wire_strings() { + let constants = [ + ("GATEWAY20_OPERATION_TYPE", GATEWAY20_OPERATION_TYPE), + ("GATEWAY20_RESOURCE_TYPE", GATEWAY20_RESOURCE_TYPE), + ("EFFECTIVE_PARTITION_KEY", EFFECTIVE_PARTITION_KEY), + ("GATEWAY20_RANGE_MIN", GATEWAY20_RANGE_MIN), + ("GATEWAY20_RANGE_MAX", GATEWAY20_RANGE_MAX), + ("GATEWAY20_USE_THINCLIENT", GATEWAY20_USE_THINCLIENT), + ]; + + for (index, (left_name, left_header)) in constants.iter().enumerate() { + for (right_name, right_header) in constants.iter().skip(index + 1) { + assert_ne!( + left_header, right_header, + "{left_name} and {right_name} must not share a wire 
string" + ); + } + } + } +} diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/diagnostics/diagnostics_context.rs b/sdk/cosmos/azure_data_cosmos_driver/src/diagnostics/diagnostics_context.rs index a1534b7409a..043648ee0e7 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/diagnostics/diagnostics_context.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/diagnostics/diagnostics_context.rs @@ -159,7 +159,7 @@ pub enum TransportSecurity { /// The concrete transport kind used for a request. /// -/// This distinguishes the standard gateway path from Gateway 2.0 thin-client +/// This distinguishes the standard gateway path from Gateway 2.0 /// routing while keeping TLS/emulator concerns in [`TransportSecurity`]. #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash, Serialize)] #[serde(rename_all = "snake_case")] @@ -169,7 +169,7 @@ pub enum TransportKind { #[default] Gateway, - /// Gateway 2.0 thin-client transport. + /// Gateway 2.0 transport. Gateway20, } diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/cache/account_metadata_cache.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/cache/account_metadata_cache.rs index dcd0a412fb8..729f6cdb2aa 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/driver/cache/account_metadata_cache.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/cache/account_metadata_cache.rs @@ -145,15 +145,23 @@ pub(crate) struct AccountProperties { /// Raw JSON string containing query engine feature/configuration flags. pub query_engine_configuration: String, - /// Regional Gateway 2.0 endpoints accepting writes (thin client mode). + /// Regional Gateway 2.0 endpoints accepting writes. /// When present, indicates that Gateway 2.0 should be used for the /// dataplane transport instead of the standard gateway endpoint. 
+ /// + /// The Rust field name retains the `thin_client_*` prefix because the + /// struct uses `#[serde(rename_all = "camelCase")]` to deserialize from + /// the wire-defined property `thinClientWritableLocations`. Renaming the + /// field would break the serde mapping. #[serde(default)] pub thin_client_writable_locations: Vec, - /// Regional Gateway 2.0 endpoints for reads (thin client mode). + /// Regional Gateway 2.0 endpoints for reads. /// When present, indicates that Gateway 2.0 should be used for the /// dataplane transport instead of the standard gateway endpoint. + /// + /// See note on `thin_client_writable_locations` for why the field name + /// retains the `thin_client_*` prefix. #[serde(default)] pub thin_client_readable_locations: Vec, @@ -184,25 +192,25 @@ impl AccountProperties { .collect() } - /// Returns `true` if Gateway 2.0 (thin client) endpoints are available. + /// Returns `true` if Gateway 2.0 endpoints are available. /// - /// When thin client locations are present in the account properties, + /// When Gateway 2.0 locations are present in the account properties, /// the driver should use Gateway 2.0 for the dataplane transport. - pub(crate) fn has_thin_client_endpoints(&self) -> bool { + pub(crate) fn has_gateway20_endpoints(&self) -> bool { !self.thin_client_writable_locations.is_empty() || !self.thin_client_readable_locations.is_empty() } - /// Returns thin client (Gateway 2.0) writable locations, if any. - pub(crate) fn thin_client_writable_regions(&self) -> Vec { + /// Returns Gateway 2.0 writable locations, if any. + pub(crate) fn gateway20_writable_regions(&self) -> Vec { self.thin_client_writable_locations .iter() .map(|loc| loc.name.clone()) .collect() } - /// Returns thin client (Gateway 2.0) readable locations, if any. - pub(crate) fn thin_client_readable_regions(&self) -> Vec { + /// Returns Gateway 2.0 readable locations, if any. 
+ pub(crate) fn gateway20_readable_regions(&self) -> Vec { self.thin_client_readable_locations .iter() .map(|loc| loc.name.clone()) diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/cosmos_driver.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/cosmos_driver.rs index 5f74982b439..25810111299 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/driver/cosmos_driver.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/cosmos_driver.rs @@ -750,7 +750,7 @@ impl CosmosDriver { account_endpoint, default_endpoint, refresh_callback, - runtime.connection_pool().is_gateway20_allowed(), + !runtime.connection_pool().gateway20_disabled(), endpoint_unavailability_ttl, options.preferred_regions().to_vec(), )); diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/pipeline/components.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/pipeline/components.rs index 6a7479de9d6..bc1b15e189a 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/driver/pipeline/components.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/pipeline/components.rs @@ -17,9 +17,12 @@ use crate::{ driver::{ jitter::with_jitter, routing::{CosmosEndpoint, LocationIndex}, - transport::AuthorizationContext, + transport::{AuthorizationContext, EndpointKey}, + }, + models::{ + CosmosResponseHeaders, CosmosStatus, DefaultConsistencyLevel, OperationType, PartitionKey, + PartitionKeyDefinition, }, - models::{CosmosResponseHeaders, CosmosStatus}, options::Region, }; @@ -47,6 +50,8 @@ pub(crate) struct RoutingDecision { pub endpoint: CosmosEndpoint, /// The concrete URL selected for this attempt. pub selected_url: Url, + /// The connection-pool key matching the selected URL's authority. + pub endpoint_key: EndpointKey, /// The transport mode for this attempt. pub transport_mode: TransportMode, } @@ -177,6 +182,16 @@ pub(crate) struct TransportRequest { pub method: Method, /// The endpoint selected for this attempt. 
pub endpoint: CosmosEndpoint, + /// The routed transport mode for this attempt. + pub transport_mode: TransportMode, + /// The operation type being dispatched. + pub operation_type: OperationType, + /// Partition key for item-scoped Gateway 2.0 dispatch. + pub partition_key: Option, + /// Partition key definition for effective partition key computation. + pub partition_key_definition: Option, + /// Effective consistency resolved from account default and read options. + pub effective_consistency: DefaultConsistencyLevel, /// The fully resolved URL for this attempt. pub url: Url, /// Headers to send (includes operation-specific and attempt-specific headers). diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/pipeline/operation_pipeline.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/pipeline/operation_pipeline.rs index fc5ad4755fe..e66a583eb8d 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/driver/pipeline/operation_pipeline.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/pipeline/operation_pipeline.rs @@ -22,7 +22,10 @@ use crate::{ request_header_names, AccountEndpoint, ActivityId, CosmosOperation, CosmosResponse, Credential, DefaultConsistencyLevel, OperationType, SessionToken, SubStatusCode, }, - options::{OperationOptionsView, ReadConsistencyStrategy, ThroughputControlGroupSnapshot}, + options::{ + resolve_effective_consistency, OperationOptionsView, ReadConsistencyStrategy, + ThroughputControlGroupSnapshot, + }, }; use super::{ @@ -34,8 +37,9 @@ use super::{ }; use crate::driver::transport::{ + is_operation_supported_by_gateway20, transport_pipeline::{execute_transport_pipeline, TransportPipelineContext}, - AuthorizationContext, CosmosTransport, + AuthorizationContext, CosmosTransport, EndpointKey, }; /// Executes a Cosmos DB operation through the new pipeline architecture. 
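
> The hunks above give `RoutingDecision` an `endpoint_key` next to `selected_url`, so a Gateway 2.0 authority (e.g. port 444) gets a connection-pool key distinct from the standard gateway's (port 443). A dependency-free sketch of that authority-keyed idea — the `endpoint_key` helper below is illustrative only and stands in for the driver's real `EndpointKey::try_from(&Url)`:

```rust
/// Illustrative pool key derived from an https URL's authority (host, port).
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct EndpointKey {
    host: String,
    port: u16,
}

// Hypothetical helper: plain string parsing in place of the `url` crate.
fn endpoint_key(url: &str) -> Option<EndpointKey> {
    let rest = url.strip_prefix("https://")?;
    let authority = rest.split('/').next()?;
    let (host, port) = match authority.rsplit_once(':') {
        Some((h, p)) => (h, p.parse().ok()?),
        None => (authority, 443), // default https port
    };
    Some(EndpointKey { host: host.to_string(), port })
}

fn main() {
    let gateway = endpoint_key("https://test-westus2.documents.azure.com:443/").unwrap();
    let gateway20 = endpoint_key("https://test-westus2-thin.documents.azure.com:444/").unwrap();
    // Different authorities (port 443 vs 444) must map to different pool keys,
    // which is why the Gateway 2.0 path cannot reuse `endpoint.endpoint_key()`.
    assert_ne!(gateway, gateway20);
    println!("gateway20 pool key: {}:{}", gateway20.host, gateway20.port);
}
```

This is a sketch of the keying strategy only; the real driver key is built from the routed `selected_url` as shown in `resolve_endpoint` below.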
@@ -73,6 +77,8 @@ pub(crate) async fn execute_operation_pipeline( .read_consistency_strategy() .copied() .unwrap_or(ReadConsistencyStrategy::Default); + let effective_consistency = + resolve_effective_consistency(read_consistency_strategy, account_default_consistency); let session_consistency_active = !session_capturing_disabled && read_consistency_strategy.is_session_effective(account_default_consistency); let max_session_retries = options @@ -112,11 +118,13 @@ pub(crate) async fn execute_operation_pipeline( let location = location_state_store.snapshot(); // ── STAGE 2: Resolve endpoint ────────────────────────────────── + let account_name = account_endpoint.global_database_account_name(); let routing = resolve_endpoint( operation, &retry_state, &location, pipeline_type == PipelineType::DataPlane, + account_name.is_some(), location_state_store.endpoint_unavailability_ttl(), ); @@ -137,6 +145,7 @@ pub(crate) async fn execute_operation_pipeline( activity_id, execution_context, deadline, + effective_consistency, resolved_session_token: session_consistency_active .then(|| { session_manager.resolve_session_token( @@ -199,7 +208,8 @@ pub(crate) async fn execute_operation_pipeline( user_agent, pipeline_type, transport_security, - endpoint_key: routing.endpoint.endpoint_key(), + endpoint_key: routing.endpoint_key.clone(), + account_name: account_name.clone(), }, &mut diagnostics, ) @@ -337,6 +347,7 @@ fn resolve_endpoint( retry_state: &OperationRetryState, location: &LocationSnapshot, prefer_gateway20: bool, + account_name_present: bool, endpoint_unavailability_ttl: Duration, ) -> RoutingDecision { let account = location.account.as_ref(); @@ -373,16 +384,28 @@ fn resolve_endpoint( } let selected = selected.unwrap_or_else(|| account.default_endpoint.clone()); - let use_gateway20 = selected.uses_gateway20(prefer_gateway20); + let use_gateway20 = selected.uses_gateway20(prefer_gateway20) + && account_name_present + && is_operation_supported_by_gateway20( + 
operation.resource_type(), + operation.operation_type(), + ); let transport_mode = if use_gateway20 { TransportMode::Gateway20 } else { TransportMode::Gateway }; + let selected_url = selected.selected_url(use_gateway20).clone(); + let endpoint_key = if use_gateway20 { + EndpointKey::try_from(&selected_url).expect("selected URL must have a valid host and port") + } else { + selected.endpoint_key() + }; RoutingDecision { - selected_url: selected.selected_url(use_gateway20).clone(), + selected_url, endpoint: selected, + endpoint_key, transport_mode, } } @@ -432,6 +455,7 @@ struct TransportRequestContext<'a> { activity_id: &'a ActivityId, execution_context: ExecutionContext, deadline: Option, + effective_consistency: DefaultConsistencyLevel, resolved_session_token: Option, throughput_control: Option<&'a ThroughputControlGroupSnapshot>, } @@ -560,6 +584,13 @@ fn build_transport_request( Ok(TransportRequest { method, endpoint: ctx.routing.endpoint.clone(), + transport_mode: ctx.routing.transport_mode, + operation_type: operation.operation_type(), + partition_key: operation.partition_key().cloned(), + partition_key_definition: operation + .container() + .map(|container| container.partition_key_definition().clone()), + effective_consistency: ctx.effective_consistency, url, headers, body: operation.body().map(azure_core::Bytes::copy_from_slice), @@ -647,11 +678,13 @@ mod tests { driver::{ pipeline::components::{RoutingDecision, TransportMode}, routing::{AccountEndpointState, CosmosEndpoint, LocationIndex, LocationSnapshot}, + transport::EndpointKey, }, models::{ request_header_names, AccountReference, ActivityId, ContainerProperties, - ContainerReference, CosmosOperation, DatabaseReference, ItemReference, PartitionKey, - PartitionKeyDefinition, SystemProperties, ThroughputControlGroupName, + ContainerReference, CosmosOperation, DatabaseReference, DefaultConsistencyLevel, + ItemReference, PartitionKey, PartitionKeyDefinition, SystemProperties, + ThroughputControlGroupName, 
}, options::{PriorityLevel, ThroughputControlGroupSnapshot}, }; @@ -691,6 +724,7 @@ mod tests { CosmosEndpoint::global(Url::parse("https://test.documents.azure.com:443/").unwrap()); RoutingDecision { selected_url: endpoint.url().clone(), + endpoint_key: endpoint.endpoint_key(), endpoint, transport_mode: TransportMode::Gateway, } @@ -707,6 +741,7 @@ mod tests { activity_id: &activity_id, execution_context: ExecutionContext::Initial, deadline: None, + effective_consistency: DefaultConsistencyLevel::Session, resolved_session_token: None, throughput_control: None, }; @@ -728,6 +763,7 @@ mod tests { activity_id: &activity_id, execution_context: ExecutionContext::Initial, deadline: None, + effective_consistency: DefaultConsistencyLevel::Session, resolved_session_token: None, throughput_control: None, }; @@ -749,6 +785,7 @@ mod tests { activity_id: &activity_id, execution_context: ExecutionContext::Initial, deadline: None, + effective_consistency: DefaultConsistencyLevel::Session, resolved_session_token: None, throughput_control: None, }; @@ -775,6 +812,7 @@ mod tests { activity_id: &activity_id, execution_context: ExecutionContext::Retry, deadline: Some(std::time::Instant::now() + Duration::from_secs(5)), + effective_consistency: DefaultConsistencyLevel::Session, resolved_session_token: None, throughput_control: None, }; @@ -792,13 +830,16 @@ mod tests { fn build_transport_request_uses_routed_endpoint_url_directly() { let operation = CosmosOperation::read_database(DatabaseReference::from_name(test_account(), "mydb")); + let selected_url = + Url::parse("https://test-westus2-thin.documents.azure.com:444/").unwrap(); let routing = RoutingDecision { endpoint: CosmosEndpoint::regional_with_gateway20( "westus2".into(), Url::parse("https://test-westus2.documents.azure.com:443/").unwrap(), - Url::parse("https://test-westus2-thin.documents.azure.com:444/").unwrap(), + selected_url.clone(), ), - selected_url: 
Url::parse("https://test-westus2-thin.documents.azure.com:444/").unwrap(), + endpoint_key: EndpointKey::try_from(&selected_url).unwrap(), + selected_url, transport_mode: TransportMode::Gateway20, }; @@ -808,6 +849,7 @@ mod tests { activity_id: &activity_id, execution_context: ExecutionContext::Initial, deadline: None, + effective_consistency: DefaultConsistencyLevel::Session, resolved_session_token: None, throughput_control: None, }; @@ -824,11 +866,12 @@ mod tests { fn build_transport_request_uses_default_url_for_global_endpoint() { let operation = CosmosOperation::read_database(DatabaseReference::from_name(test_account(), "mydb")); + let endpoint = + CosmosEndpoint::global(Url::parse("https://test.documents.azure.com:443/").unwrap()); let routing = RoutingDecision { - endpoint: CosmosEndpoint::global( - Url::parse("https://test.documents.azure.com:443/").unwrap(), - ), - selected_url: Url::parse("https://test.documents.azure.com:443/").unwrap(), + selected_url: endpoint.url().clone(), + endpoint_key: endpoint.endpoint_key(), + endpoint, transport_mode: TransportMode::Gateway, }; @@ -838,6 +881,7 @@ mod tests { activity_id: &activity_id, execution_context: ExecutionContext::Initial, deadline: None, + effective_consistency: DefaultConsistencyLevel::Session, resolved_session_token: None, throughput_control: None, }; @@ -888,6 +932,7 @@ mod tests { &retry_state, &location, false, + true, Duration::from_secs(60), ); assert_eq!(routing.endpoint, write_endpoint); @@ -938,6 +983,7 @@ mod tests { &retry_state, &location, false, + true, Duration::from_secs(60), ); assert_eq!(routing.endpoint, default_endpoint); @@ -986,6 +1032,7 @@ mod tests { &retry_state, &location, false, + true, Duration::from_secs(60), ); assert_eq!(routing.endpoint, read_endpoint); @@ -1043,6 +1090,7 @@ mod tests { &stale_retry_state, &location, false, + true, Duration::from_secs(60), ); assert_eq!(first_routing.endpoint, endpoint_a); @@ -1056,6 +1104,7 @@ mod tests { &advanced_state, &location, 
false, + true, Duration::from_secs(60), ); assert_eq!(second_routing.endpoint, endpoint_b); @@ -1199,6 +1248,7 @@ mod tests { &retry_state, &location, true, + true, Duration::from_secs(60), ); assert_eq!(routing.endpoint, endpoint); @@ -1209,6 +1259,139 @@ mod tests { ); } + #[test] + fn resolve_endpoint_falls_back_to_gateway_when_op_ineligible_for_gateway20() { + let operation = CosmosOperation::read_all_databases(test_account()); + let endpoint = CosmosEndpoint::regional_with_gateway20( + "westus2".into(), + Url::parse("https://test-westus2.documents.azure.com:443/").unwrap(), + Url::parse("https://test-westus2-thin.documents.azure.com:444/").unwrap(), + ); + + let location = LocationSnapshot::for_tests(Arc::new(AccountEndpointState { + generation: 0, + preferred_read_endpoints: vec![endpoint.clone()].into(), + preferred_write_endpoints: vec![endpoint.clone()].into(), + unavailable_endpoints: Default::default(), + multiple_write_locations_enabled: false, + default_endpoint: endpoint.clone(), + })); + + let retry_state = crate::driver::pipeline::components::OperationRetryState::initial( + 0, + false, + Vec::new(), + 3, + 2, + ); + + let routing = super::resolve_endpoint( + &operation, + &retry_state, + &location, + true, + true, + Duration::from_secs(60), + ); + + assert_eq!(routing.transport_mode, TransportMode::Gateway); + assert_eq!(routing.selected_url, *endpoint.url()); + } + + #[test] + fn resolve_endpoint_falls_back_to_gateway_when_account_name_unparseable() { + let operation = CosmosOperation::read_item(ItemReference::from_name( + &test_container(), + PartitionKey::from("pk1"), + "doc1", + )); + let endpoint = CosmosEndpoint::regional_with_gateway20( + "westus2".into(), + Url::parse("https://test-westus2.documents.azure.com:443/").unwrap(), + Url::parse("https://test-westus2-thin.documents.azure.com:444/").unwrap(), + ); + + let location = LocationSnapshot::for_tests(Arc::new(AccountEndpointState { + generation: 0, + preferred_read_endpoints: 
vec![endpoint.clone()].into(), + preferred_write_endpoints: vec![endpoint.clone()].into(), + unavailable_endpoints: Default::default(), + multiple_write_locations_enabled: false, + default_endpoint: endpoint.clone(), + })); + + let retry_state = crate::driver::pipeline::components::OperationRetryState::initial( + 0, + false, + Vec::new(), + 3, + 2, + ); + + let routing = super::resolve_endpoint( + &operation, + &retry_state, + &location, + true, + false, + Duration::from_secs(60), + ); + + assert_eq!(routing.transport_mode, TransportMode::Gateway); + assert_eq!(routing.selected_url, *endpoint.url()); + } + + #[test] + fn resolve_endpoint_uses_gateway20_authority_for_endpoint_key() { + let operation = CosmosOperation::read_item(ItemReference::from_name( + &test_container(), + PartitionKey::from("pk1"), + "doc1", + )); + let gateway20_url = Url::parse("https://central.gateway20.azure.com:444/").unwrap(); + let endpoint = CosmosEndpoint::regional_with_gateway20( + "centralus".into(), + Url::parse("https://central.documents.azure.com:443/").unwrap(), + gateway20_url.clone(), + ); + + let location = LocationSnapshot::for_tests(Arc::new(AccountEndpointState { + generation: 0, + preferred_read_endpoints: vec![endpoint.clone()].into(), + preferred_write_endpoints: vec![endpoint.clone()].into(), + unavailable_endpoints: Default::default(), + multiple_write_locations_enabled: false, + default_endpoint: endpoint, + })); + + let retry_state = crate::driver::pipeline::components::OperationRetryState::initial( + 0, + false, + Vec::new(), + 3, + 2, + ); + + let routing = super::resolve_endpoint( + &operation, + &retry_state, + &location, + true, + true, + Duration::from_secs(60), + ); + + assert_eq!(routing.transport_mode, TransportMode::Gateway20); + assert_eq!( + routing.selected_url.host_str(), + Some("central.gateway20.azure.com") + ); + assert_eq!( + routing.endpoint_key, + EndpointKey::try_from(&gateway20_url).unwrap() + ); + } + #[test] fn 
resolve_endpoint_skips_unavailable_region_when_gateway20_is_present() { let operation = CosmosOperation::read_item(ItemReference::from_name( @@ -1257,6 +1440,7 @@ mod tests { &retry_state, &location, true, + true, Duration::from_secs(60), ); assert_eq!(routing.endpoint, fallback_endpoint); @@ -1275,6 +1459,7 @@ mod tests { activity_id: &activity_id, execution_context: ExecutionContext::Initial, deadline: None, + effective_consistency: DefaultConsistencyLevel::Session, resolved_session_token: None, throughput_control: None, }; @@ -1308,6 +1493,7 @@ mod tests { activity_id: &activity_id, execution_context: ExecutionContext::Initial, deadline: None, + effective_consistency: DefaultConsistencyLevel::Session, resolved_session_token: None, throughput_control: None, }; @@ -1344,6 +1530,7 @@ mod tests { deadline: None, resolved_session_token: None, throughput_control: None, + effective_consistency: DefaultConsistencyLevel::Session, }; let request = build_transport_request(&operation, None, &ctx).expect("request should build"); @@ -1390,6 +1577,7 @@ mod tests { deadline: None, resolved_session_token: None, throughput_control: None, + effective_consistency: DefaultConsistencyLevel::Session, }; let request = build_transport_request(&operation, None, &ctx).expect("request should build"); @@ -1426,6 +1614,7 @@ mod tests { activity_id: &activity_id, execution_context: ExecutionContext::Initial, deadline: None, + effective_consistency: DefaultConsistencyLevel::Session, resolved_session_token: None, throughput_control: Some(&snapshot), }; @@ -1469,6 +1658,7 @@ mod tests { activity_id: &activity_id, execution_context: ExecutionContext::Initial, deadline: None, + effective_consistency: DefaultConsistencyLevel::Session, resolved_session_token: None, throughput_control: Some(&snapshot), }; @@ -1513,6 +1703,7 @@ mod tests { activity_id: &activity_id, execution_context: ExecutionContext::Initial, deadline: None, + effective_consistency: DefaultConsistencyLevel::Session, 
resolved_session_token: None, throughput_control: Some(&snapshot), }; diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/routing/routing_systems.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/routing/routing_systems.rs index 8ab9d704bcb..0abb17ba0d9 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/driver/routing/routing_systems.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/routing/routing_systems.rs @@ -69,11 +69,11 @@ pub(crate) fn build_account_endpoint_state( fn build_preferred_endpoints( standard_locations: &[crate::driver::cache::AccountRegion], - thin_client_locations: &[crate::driver::cache::AccountRegion], + gateway20_locations: &[crate::driver::cache::AccountRegion], gateway20_enabled: bool, ) -> Vec { - let thin_client_urls = if gateway20_enabled { - parse_thin_client_locations(thin_client_locations) + let gateway20_urls = if gateway20_enabled { + parse_gateway20_locations(gateway20_locations) } else { HashMap::new() }; @@ -82,7 +82,7 @@ fn build_preferred_endpoints( for region in standard_locations { let url = region.database_account_endpoint.url().clone(); - let endpoint = thin_client_urls + let endpoint = gateway20_urls .get(®ion.name) .cloned() .map(|gateway20_url| { @@ -100,12 +100,12 @@ fn build_preferred_endpoints( endpoints } -fn parse_thin_client_locations( - thin_client_locations: &[crate::driver::cache::AccountRegion], +fn parse_gateway20_locations( + gateway20_locations: &[crate::driver::cache::AccountRegion], ) -> HashMap { let mut urls = HashMap::new(); - for region in thin_client_locations { + for region in gateway20_locations { let url = region.database_account_endpoint.url().clone(); if url.scheme() != "https" { @@ -113,7 +113,7 @@ fn parse_thin_client_locations( region = %region.name, endpoint = %region.database_account_endpoint, scheme = url.scheme(), - "Ignoring non-HTTPS thin-client endpoint URL" + "Ignoring non-HTTPS Gateway 2.0 endpoint URL" ); continue; } @@ -125,7 +125,7 @@ fn 
parse_thin_client_locations( region = %region.name, existing_url = %existing, new_url = %url, - "Duplicate thin-client region with conflicting URL; keeping first entry" + "Duplicate Gateway 2.0 region with conflicting URL; keeping first entry" ); } }) diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/adaptive_transport.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/adaptive_transport.rs index 567f88c7f14..07a32bd4bac 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/adaptive_transport.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/adaptive_transport.rs @@ -19,7 +19,7 @@ use crate::options::ConnectionPoolOptions; /// `Gateway` is an unsharded HTTP/1.1 transport used when the gateway does not /// support HTTP/2. `ShardedGateway` is a per-endpoint sharded HTTP/2 transport /// used when HTTP/2 has been confirmed via the initialization probe. -/// `ShardedGateway20` is reserved for Gateway 2.0 thin-client requests and +/// `ShardedGateway20` is reserved for Gateway 2.0 requests and /// always uses HTTP/2 prior knowledge. 
#[derive(Clone)]
pub(crate) enum AdaptiveTransport {

diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/cosmos_headers.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/cosmos_headers.rs
index 734b175a2ac..80b9553d3fd 100644
--- a/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/cosmos_headers.rs
+++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/cosmos_headers.rs
@@ -13,7 +13,16 @@ const APPLICATION_JSON: HeaderValue = HeaderValue::from_static("application/json
 const VERSION: HeaderName = HeaderName::from_static("x-ms-version");
 const SDK_SUPPORTED_CAPABILITIES: HeaderName =
     HeaderName::from_static("x-ms-cosmos-sdk-supportedcapabilities");
-const SUPPORTED_CAPABILITIES_VALUE: &str = "0";
+const PARTITION_MERGE_BIT: u32 = 1;
+const IGNORE_UNKNOWN_RNTBD_TOKENS_BIT: u32 = 8;
+pub(crate) const SUPPORTED_CAPABILITIES_BITS: u32 =
+    PARTITION_MERGE_BIT | IGNORE_UNKNOWN_RNTBD_TOKENS_BIT;
+const _: () = assert!(SUPPORTED_CAPABILITIES_BITS == 9);
+/// String-encoded SDK capabilities bitmask.
+///
+/// Derived from `PartitionMerge` (1) | `IgnoreUnknownRntbdTokens` (8); the
+/// latter advertises to Gateway 2.0 that the SDK tolerates unknown RNTBD
+/// tokens, for forward compatibility.
+const SUPPORTED_CAPABILITIES_VALUE: &str = "9"; const CACHE_CONTROL: HeaderName = HeaderName::from_static("cache-control"); const NO_CACHE: HeaderValue = HeaderValue::from_static("no-cache"); @@ -40,3 +49,37 @@ pub(crate) fn apply_cosmos_headers(request: &mut HttpRequest, user_agent: &Heade request.headers.insert(CACHE_CONTROL, NO_CACHE.clone()); request.headers.insert(USER_AGENT, user_agent.clone()); } + +#[cfg(test)] +mod tests { + use super::*; + use azure_core::http::{headers::Headers, Method}; + use url::Url; + + #[test] + fn applies_supported_capabilities_bitmask() { + let mut request = HttpRequest { + url: Url::parse("https://example.documents.azure.com/").unwrap(), + method: Method::Get, + headers: Headers::new(), + body: None, + timeout: None, + #[cfg(feature = "fault_injection")] + evaluation_collector: None, + }; + let user_agent = HeaderValue::from_static("test-agent"); + + apply_cosmos_headers(&mut request, &user_agent); + + assert_eq!( + SUPPORTED_CAPABILITIES_VALUE.parse::().unwrap(), + PARTITION_MERGE_BIT | IGNORE_UNKNOWN_RNTBD_TOKENS_BIT + ); + assert_eq!( + request + .headers + .get_optional_str(&SDK_SUPPORTED_CAPABILITIES), + Some(SUPPORTED_CAPABILITIES_VALUE) + ); + } +} diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/gateway20_dispatch.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/gateway20_dispatch.rs new file mode 100644 index 00000000000..28487d07add --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/gateway20_dispatch.rs @@ -0,0 +1,1143 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//! Gateway 2.0 HTTP dispatch helpers. 
+ +use std::sync::atomic::{AtomicU32, Ordering}; + +use azure_core::{ + error::ErrorKind, + http::{ + headers::{HeaderName, HeaderValue, Headers, AUTHORIZATION, USER_AGENT}, + Method, + }, +}; +use uuid::Uuid; + +use crate::{ + constants::{GATEWAY20_RANGE_MAX, GATEWAY20_RANGE_MIN}, + models::{ + cosmos_headers::response_header_names, effective_partition_key::EffectivePartitionKey, + DefaultConsistencyLevel, OperationType, PartitionKey, PartitionKeyDefinition, ResourceType, + }, +}; + +use super::{ + cosmos_headers::SUPPORTED_CAPABILITIES_BITS, + cosmos_transport_client::{HttpRequest, HttpResponse}, + rntbd::{RntbdRequestFrame, RntbdResponse, Token}, + AuthorizationContext, +}; + +const X_MS_ACTIVITY_ID: HeaderName = HeaderName::from_static("x-ms-activity-id"); +const X_MS_DATE: HeaderName = HeaderName::from_static("x-ms-date"); +const X_MS_LSN: HeaderName = HeaderName::from_static("x-ms-lsn"); +const X_MS_GLOBAL_COMMITTED_LSN: HeaderName = HeaderName::from_static("x-ms-global-committed-lsn"); +const X_MS_CONTINUATION: HeaderName = HeaderName::from_static("x-ms-continuation"); +static TRANSPORT_REQUEST_ID: AtomicU32 = AtomicU32::new(0); + +/// Inputs resolved by the operation pipeline before a Gateway 2.0 dispatch. +pub(crate) struct WrapInputs<'a> { + pub(crate) auth_context: &'a AuthorizationContext, + pub(crate) operation_type: OperationType, + pub(crate) resource_type: ResourceType, + pub(crate) partition_key: Option<&'a PartitionKey>, + pub(crate) partition_key_definition: Option<&'a PartitionKeyDefinition>, + pub(crate) effective_consistency: DefaultConsistencyLevel, + pub(crate) account_name: Option<&'a str>, +} + +/// Wraps a signed Cosmos HTTP request into a Gateway 2.0 RNTBD request frame. 
+pub(crate) fn wrap_request_for_gateway20( + request: &HttpRequest, + inputs: &WrapInputs<'_>, +) -> azure_core::Result<HttpRequest> { + let authorization = required_header(request, &AUTHORIZATION, "authorization")?; + let date = required_header(request, &X_MS_DATE, "x-ms-date")?; + let activity_id = required_header(request, &X_MS_ACTIVITY_ID, "x-ms-activity-id")?; + let activity_id = Uuid::parse_str(&activity_id) + .map_err(|e| data_conversion_error(format!("x-ms-activity-id is not a valid UUID: {e}")))?; + let account_name = inputs + .account_name + .filter(|value| !value.is_empty()) + .ok_or_else(|| data_conversion_error("Gateway 2.0 dispatch requires an account name"))?; + + let resource_names = parse_resource_names(inputs.auth_context.resource_link.as_str())?; + let has_payload = request.body.as_ref().is_some_and(|body| !body.is_empty()); + + let epk_payload = effective_partition_key_payload(inputs)?; + + let mut metadata = Vec::with_capacity(11); + if let Some(EpkPayload::Point(epk)) = epk_payload.as_ref() { + metadata.push(Token::effective_partition_key(epk.clone())); + } + metadata.push(Token::global_database_account_name(account_name.to_owned())); + metadata.push(Token::database_name(resource_names.database)); + metadata.push(Token::collection_name(resource_names.collection)); + metadata.push(Token::payload_present(has_payload)); + if inputs.resource_type == ResourceType::Document + && inputs.operation_type != OperationType::Create + { + if let Some(document) = resource_names.document { + metadata.push(Token::document_name(document)); + } + } + metadata.push(Token::authorization_token(authorization)); + metadata.push(Token::date(date)); + metadata.push(Token::consistency_level(inputs.effective_consistency)); + metadata.push(Token::transport_request_id(next_transport_request_id())); + metadata.push(Token::sdk_supported_capabilities( + SUPPORTED_CAPABILITIES_BITS, + )); + if let Some(continuation) = request.headers.get_optional_str(&X_MS_CONTINUATION) { + 
metadata.push(Token::continuation_token(continuation.to_owned())); + } + + let frame = RntbdRequestFrame { + resource_type: inputs.resource_type, + operation_type: inputs.operation_type, + activity_id, + metadata, + body: if has_payload { + request.body.as_ref().map(|body| body.to_vec()) + } else { + None + }, + } + .serialize()?; + + let mut headers = Headers::new(); + if let Some(user_agent) = request.headers.get_optional_str(&USER_AGENT) { + headers.insert(USER_AGENT, HeaderValue::from(user_agent.to_owned())); + } + headers.insert(X_MS_ACTIVITY_ID, HeaderValue::from(activity_id.to_string())); + if let Some(EpkPayload::Range { min, max }) = epk_payload.as_ref() { + headers.insert(GATEWAY20_RANGE_MIN, HeaderValue::from(min.clone())); + headers.insert(GATEWAY20_RANGE_MAX, HeaderValue::from(max.clone())); + } + + Ok(HttpRequest { + url: request.url.clone(), + method: Method::Post, + headers, + body: Some(bytes::Bytes::from(frame)), + timeout: request.timeout, + #[cfg(feature = "fault_injection")] + evaluation_collector: request.evaluation_collector.clone(), + }) +} + +/// Decodes a Gateway 2.0 RNTBD response body into a synthetic HTTP response. 
+pub(crate) fn unwrap_response_for_gateway20( + response: HttpResponse, +) -> azure_core::Result<HttpResponse> { + let response = RntbdResponse::deserialize(&response.body)?; + let status = u16::from(response.status.status_code()); + if !(100..=599).contains(&status) { + return Err(data_conversion_error(format!( + "Gateway 2.0 RNTBD response contained invalid HTTP status {status}" + ))); + } + + let mut headers = Headers::new(); + headers.insert( + response_header_names::ACTIVITY_ID, + response.activity_id.to_string(), + ); + if let Some(charge) = response.request_charge { + headers.insert(response_header_names::REQUEST_CHARGE, charge.to_string()); + } + if let Some(token) = response.session_token { + headers.insert(response_header_names::SESSION_TOKEN, token); + } + if let Some(etag) = response.etag { + headers.insert(response_header_names::ETAG, etag); + } + if let Some(continuation) = response.continuation_token { + headers.insert(response_header_names::CONTINUATION, continuation); + } + if let Some(substatus) = response.status.sub_status() { + headers.insert( + response_header_names::SUBSTATUS, + substatus.value().to_string(), + ); + } + if let Some(retry_after_ms) = response.retry_after_ms { + headers.insert("x-ms-retry-after-ms", retry_after_ms.to_string()); + } + if let Some(lsn) = response.lsn.filter(|value| *value != 0) { + let value = lsn.to_string(); + headers.insert(response_header_names::LSN, value.clone()); + headers.insert(X_MS_LSN, value); + } + if let Some(item_lsn) = response.item_lsn.filter(|value| *value != 0) { + headers.insert(response_header_names::ITEM_LSN, item_lsn.to_string()); + } + if let Some(global_committed_lsn) = response.global_committed_lsn.filter(|value| *value != 0) { + headers.insert(X_MS_GLOBAL_COMMITTED_LSN, global_committed_lsn.to_string()); + } + if let Some(owner_full_name) = response.owner_full_name { + headers.insert(response_header_names::OWNER_FULL_NAME, owner_full_name); + } + + Ok(HttpResponse { + status, + headers, + body: 
response.body, + }) +} + +fn required_header( + request: &HttpRequest, + header_name: &HeaderName, + display_name: &'static str, +) -> azure_core::Result<String> { + request + .headers + .get_optional_str(header_name) + .map(str::to_owned) + .ok_or_else(|| data_conversion_error(format!("missing required {display_name} header"))) +} + +fn next_transport_request_id() -> u32 { + // `Relaxed` would be sufficient here: `fetch_add` is atomic regardless of + // ordering, so the IDs are unique across threads either way. `AcqRel` is + // kept as a conservative choice; this counter is not on a hot path. + TRANSPORT_REQUEST_ID.fetch_add(1, Ordering::AcqRel) +} + +/// Wire-form payload derived from the partition key + definition for a +/// Gateway 2.0 dispatch. +/// +/// `Point` represents a single-logical-partition operation and is emitted as +/// the `EffectivePartitionKey` RNTBD metadata token (binary EPK bytes). +/// `Range` represents an EPK range — either a hierarchical-PK prefix that +/// fans out across multiple physical partitions, or a feed/cross-partition +/// operation scoped to a sub-range — and is emitted as the +/// `x-ms-thinclient-range-min` / `-max` outer HTTP headers carrying the +/// canonical, un-padded hex EPK string per `GATEWAY_20_SPEC §"Range header +/// wire format"`. +/// +/// The two arms are mutually exclusive; the proxy must never see both an +/// EPK token and EPK range headers on the same request. 
+enum EpkPayload { + Point(Vec<u8>), + Range { min: String, max: String }, +} + +fn effective_partition_key_payload( + inputs: &WrapInputs<'_>, +) -> azure_core::Result<Option<EpkPayload>> { + let (Some(partition_key), Some(partition_key_definition)) = + (inputs.partition_key, inputs.partition_key_definition) + else { + return Ok(None); + }; + + if partition_key.is_empty() { + return Ok(None); + } + + let range = + EffectivePartitionKey::compute_range(partition_key.values(), partition_key_definition) + .map_err(|err| { + data_conversion_error(format!("Gateway 2.0 EPK range computation failed: {err}")) + })?; + + if range.start == range.end { + let bytes = hex_to_bytes(range.start.as_str())?; + Ok(Some(EpkPayload::Point(bytes))) + } else { + Ok(Some(EpkPayload::Range { + min: range.start.as_str().to_owned(), + max: range.end.as_str().to_owned(), + })) + } +} + +fn hex_to_bytes(value: &str) -> azure_core::Result<Vec<u8>> { + if value.len() & 1 != 0 { + return Err(data_conversion_error(format!( + "effective partition key hex length {} is not even", + value.len() + ))); + } + + let mut bytes = Vec::with_capacity(value.len() / 2); + for chunk in value.as_bytes().chunks_exact(2) { + let hi = hex_digit(chunk[0])?; + let lo = hex_digit(chunk[1])?; + bytes.push((hi << 4) | lo); + } + Ok(bytes) +} + +fn hex_digit(value: u8) -> azure_core::Result<u8> { + match value { + b'0'..=b'9' => Ok(value - b'0'), + b'a'..=b'f' => Ok(value - b'a' + 10), + b'A'..=b'F' => Ok(value - b'A' + 10), + _ => Err(data_conversion_error(format!( + "invalid effective partition key hex digit 0x{value:02X}" + ))), + } +} + +struct ResourceNames { + database: String, + collection: String, + document: Option<String>, +} + +fn parse_resource_names(resource_link: &str) -> azure_core::Result<ResourceNames> { + let mut database = None; + let mut collection = None; + let mut document = None; + let mut segments = resource_link + .trim_matches('/') + .split('/') + .filter(|segment| !segment.is_empty()); + + while let Some(kind) = segments.next() { + let Some(name) = 
segments.next() else { + break; + }; + match kind { + "dbs" => database = Some(name.to_owned()), + "colls" => collection = Some(name.to_owned()), + "docs" => document = Some(name.to_owned()), + _ => {} + } + } + + let database = database.filter(|value| !value.is_empty()).ok_or_else(|| { + data_conversion_error("Gateway 2.0 resource link is missing database name") + })?; + let collection = collection + .filter(|value| !value.is_empty()) + .ok_or_else(|| { + data_conversion_error("Gateway 2.0 resource link is missing collection name") + })?; + + Ok(ResourceNames { + database, + collection, + document, + }) +} + +fn data_conversion_error(message: impl Into<String>) -> azure_core::Error { + azure_core::Error::with_message(ErrorKind::DataConversion, message.into()) +} + +#[cfg(test)] +mod tests { + use std::{borrow::Cow, collections::HashMap}; + + use azure_core::http::headers::{ACCEPT, CONTENT_TYPE}; + + use super::*; + use crate::models::{PartitionKeyKind, PartitionKeyValue, PartitionKeyVersion}; + + const ACTIVITY_ID: &str = "00112233-4455-6677-8899-aabbccddeeff"; + + #[derive(Clone, Debug, PartialEq)] + enum ParsedTokenValue { + Byte(u8), + ULong(u32), + LongLong(i64), + Double(f64), + SmallString(String), + String(String), + Bytes(Vec<u8>), + } + + #[derive(Debug)] + struct ParsedRequest { + resource_type: u16, + operation_type: u16, + activity_id: Uuid, + tokens: HashMap<u16, ParsedTokenValue>, + body: Option<Vec<u8>>, + } + + fn signed_request(body: Option<&[u8]>) -> HttpRequest { + let mut headers = Headers::new(); + headers.insert(AUTHORIZATION, "auth-token"); + headers.insert(X_MS_DATE, "Wed, 21 Oct 2015 07:28:00 GMT"); + headers.insert(X_MS_ACTIVITY_ID, ACTIVITY_ID); + headers.insert(USER_AGENT, "test-agent"); + headers.insert(CONTENT_TYPE, "application/json"); + headers.insert(ACCEPT, "application/json"); + + HttpRequest { + url: "https://account-thin.documents.azure.com:444/dbs/db1/colls/coll1/docs/doc1" + .parse() + .unwrap(), + method: Method::Get, + headers, + body: 
body.map(bytes::Bytes::copy_from_slice), + timeout: None, + #[cfg(feature = "fault_injection")] + evaluation_collector: None, + } + } + + fn wrap_inputs<'a>( + auth_context: &'a AuthorizationContext, + operation_type: OperationType, + partition_key: Option<&'a PartitionKey>, + partition_key_definition: Option<&'a PartitionKeyDefinition>, + ) -> WrapInputs<'a> { + WrapInputs { + auth_context, + operation_type, + resource_type: ResourceType::Document, + partition_key, + partition_key_definition, + effective_consistency: DefaultConsistencyLevel::Session, + account_name: Some("account"), + } + } + + fn parse_wrapped_request(request: &HttpRequest, token_count: usize) -> ParsedRequest { + let mut src = request.body.as_ref().unwrap().as_ref(); + let total_len = take_u32(&mut src) as usize; + assert_eq!(total_len, request.body.as_ref().unwrap().len()); + let resource_type = take_u16(&mut src); + let operation_type = take_u16(&mut src); + let activity_id = take_uuid(&mut src); + + let mut tokens = HashMap::new(); + for _ in 0..token_count { + let id = take_u16(&mut src); + let token_type = take_u8(&mut src); + let value = parse_token_value(token_type, &mut src); + tokens.insert(id, value); + } + + let body = if src.is_empty() { + None + } else { + let body_len = take_u32(&mut src) as usize; + assert_eq!(src.len(), body_len); + Some(src.to_vec()) + }; + + ParsedRequest { + resource_type, + operation_type, + activity_id, + tokens, + body, + } + } + + fn parse_token_value(token_type: u8, src: &mut &[u8]) -> ParsedTokenValue { + match token_type { + 0x00 => ParsedTokenValue::Byte(take_u8(src)), + 0x02 => ParsedTokenValue::ULong(take_u32(src)), + 0x05 => ParsedTokenValue::LongLong(take_i64(src)), + 0x07 => { + let len = take_u8(src) as usize; + ParsedTokenValue::SmallString(take_string(src, len)) + } + 0x08 => { + let len = take_u16(src) as usize; + ParsedTokenValue::String(take_string(src, len)) + } + 0x0B => { + let len = take_u16(src) as usize; + 
ParsedTokenValue::Bytes(take_bytes(src, len).to_vec()) + } + 0x0E => ParsedTokenValue::Double(f64::from_le_bytes(take_array(src))), + other => panic!("unexpected token type 0x{other:02X}"), + } + } + + #[test] + fn wrap_builds_required_request_tokens_for_read() { + let request = signed_request(None); + let auth_context = AuthorizationContext::new( + Method::Get, + ResourceType::Document, + "dbs/db1/colls/coll1/docs/doc1", + ); + + let wrapped = wrap_request_for_gateway20( + &request, + &wrap_inputs(&auth_context, OperationType::Read, None, None), + ) + .unwrap(); + let parsed = parse_wrapped_request(&wrapped, 10); + + assert_eq!(wrapped.method, Method::Post); + assert_eq!(parsed.resource_type, 0x0003); + assert_eq!(parsed.operation_type, 0x0003); + assert_eq!(parsed.activity_id, Uuid::parse_str(ACTIVITY_ID).unwrap()); + assert_eq!( + parsed.tokens[&0x0001], + ParsedTokenValue::String("auth-token".into()) + ); + assert_eq!(parsed.tokens[&0x0002], ParsedTokenValue::Byte(0)); + assert_eq!( + parsed.tokens[&0x0003], + ParsedTokenValue::SmallString("Wed, 21 Oct 2015 07:28:00 GMT".into()) + ); + assert_eq!(parsed.tokens[&0x0010], ParsedTokenValue::Byte(0x02)); + assert_eq!( + parsed.tokens[&0x0015], + ParsedTokenValue::String("db1".into()) + ); + assert_eq!( + parsed.tokens[&0x0016], + ParsedTokenValue::String("coll1".into()) + ); + assert_eq!( + parsed.tokens[&0x0017], + ParsedTokenValue::String("doc1".into()) + ); + assert_eq!( + parsed.tokens[&0x004D], + ParsedTokenValue::ULong(parsed_transport_id(&parsed)) + ); + assert_eq!( + parsed.tokens[&0x00A2], + ParsedTokenValue::ULong(SUPPORTED_CAPABILITIES_BITS) + ); + assert_eq!( + parsed.tokens[&0x00CE], + ParsedTokenValue::String("account".into()) + ); + } + + #[test] + fn wrap_preserves_payload_and_sets_payload_present() { + let request = signed_request(Some(br#"{"id":"doc1"}"#)); + let auth_context = + AuthorizationContext::new(Method::Post, ResourceType::Document, "dbs/db1/colls/coll1"); + + let wrapped = 
wrap_request_for_gateway20( + &request, + &wrap_inputs(&auth_context, OperationType::Create, None, None), + ) + .unwrap(); + let parsed = parse_wrapped_request(&wrapped, 9); + + assert_eq!(parsed.tokens[&0x0002], ParsedTokenValue::Byte(1)); + assert_eq!(parsed.body, Some(br#"{"id":"doc1"}"#.to_vec())); + } + + #[test] + fn wrap_omits_document_name_for_create() { + let request = signed_request(Some(b"{}")); + let auth_context = + AuthorizationContext::new(Method::Post, ResourceType::Document, "dbs/db1/colls/coll1"); + + let wrapped = wrap_request_for_gateway20( + &request, + &wrap_inputs(&auth_context, OperationType::Create, None, None), + ) + .unwrap(); + let parsed = parse_wrapped_request(&wrapped, 9); + + assert!(!parsed.tokens.contains_key(&0x0017)); + } + + #[test] + fn wrap_uses_resolved_consistency_token() { + let request = signed_request(None); + let auth_context = AuthorizationContext::new( + Method::Get, + ResourceType::Document, + "dbs/db1/colls/coll1/docs/doc1", + ); + let mut inputs = wrap_inputs(&auth_context, OperationType::Read, None, None); + inputs.effective_consistency = DefaultConsistencyLevel::Eventual; + + let wrapped = wrap_request_for_gateway20(&request, &inputs).unwrap(); + let parsed = parse_wrapped_request(&wrapped, 10); + + assert_eq!(parsed.tokens[&0x0010], ParsedTokenValue::Byte(0x03)); + } + + #[test] + fn wrap_computes_effective_partition_key_bytes() { + let request = signed_request(None); + let auth_context = AuthorizationContext::new( + Method::Get, + ResourceType::Document, + "dbs/db1/colls/coll1/docs/doc1", + ); + let partition_key = PartitionKey::from("tenant1"); + let partition_key_definition = PartitionKeyDefinition::new(vec![Cow::from("/tenantId")]); + let expected = hex_to_bytes( + EffectivePartitionKey::compute( + partition_key.values(), + PartitionKeyKind::Hash, + PartitionKeyVersion::V2, + ) + .as_str(), + ) + .unwrap(); + + let wrapped = wrap_request_for_gateway20( + &request, + &wrap_inputs( + &auth_context, + 
OperationType::Read, + Some(&partition_key), + Some(&partition_key_definition), + ), + ) + .unwrap(); + let parsed = parse_wrapped_request(&wrapped, 11); + + assert_eq!(parsed.tokens[&0x005A], ParsedTokenValue::Bytes(expected)); + } + + /// HPK partial-PK (prefix on a MultiHash container) is dispatched as an + /// EPK *range* via the outer `x-ms-thinclient-range-min`/`-max` HTTP + /// headers, not as an `EffectivePartitionKey` RNTBD token. The two + /// emission paths must be mutually exclusive. + #[test] + fn wrap_emits_range_headers_for_hpk_prefix_partition_key() { + let request = signed_request(None); + let auth_context = + AuthorizationContext::new(Method::Get, ResourceType::Document, "dbs/db1/colls/coll1"); + let partition_key = + PartitionKey::from(vec![PartitionKeyValue::from("tenant1".to_string())]); + let partition_key_definition = + PartitionKeyDefinition::from(("/tenantId", "/userId", "/sessionId")); + let expected_range = + EffectivePartitionKey::compute_range(partition_key.values(), &partition_key_definition) + .unwrap(); + assert_ne!( + expected_range.start, expected_range.end, + "HPK prefix must produce a non-point range — sanity check" + ); + + let wrapped = wrap_request_for_gateway20( + &request, + &wrap_inputs( + &auth_context, + OperationType::Query, + Some(&partition_key), + Some(&partition_key_definition), + ), + ) + .unwrap(); + + // Range headers on the outer HTTP request, carrying canonical un-padded hex. + assert_eq!( + wrapped.headers.get_optional_str(&GATEWAY20_RANGE_MIN), + Some(expected_range.start.as_str()) + ); + assert_eq!( + wrapped.headers.get_optional_str(&GATEWAY20_RANGE_MAX), + Some(expected_range.end.as_str()) + ); + + // No EPK token in the inner RNTBD frame for the range path. + // Token layout: 9 base tokens (account, db, coll, payload_present, + // auth, date, consistency, transport_request_id, capabilities) — no + // document_name (resource link omits /docs/...) and no EPK token. 
+ let parsed = parse_wrapped_request(&wrapped, 9); + assert!( + !parsed.tokens.contains_key(&0x005A), + "EffectivePartitionKey token must not be emitted alongside range headers" + ); + } + + /// Full HPK key (component count == definition path count) collapses to a + /// point op: emit the EPK token, no range headers. + #[test] + fn wrap_emits_token_only_for_full_hpk_partition_key() { + let request = signed_request(None); + let auth_context = AuthorizationContext::new( + Method::Get, + ResourceType::Document, + "dbs/db1/colls/coll1/docs/doc1", + ); + let partition_key = PartitionKey::from(vec![ + PartitionKeyValue::from("tenant1".to_string()), + PartitionKeyValue::from("user1".to_string()), + PartitionKeyValue::from("session1".to_string()), + ]); + let partition_key_definition = + PartitionKeyDefinition::from(("/tenantId", "/userId", "/sessionId")); + + let wrapped = wrap_request_for_gateway20( + &request, + &wrap_inputs( + &auth_context, + OperationType::Read, + Some(&partition_key), + Some(&partition_key_definition), + ), + ) + .unwrap(); + + // Range headers must NOT be present on the point path. + assert!(wrapped + .headers + .get_optional_str(&GATEWAY20_RANGE_MIN) + .is_none()); + assert!(wrapped + .headers + .get_optional_str(&GATEWAY20_RANGE_MAX) + .is_none()); + + // EPK token present in the inner RNTBD frame. + let parsed = parse_wrapped_request(&wrapped, 11); + assert!( + parsed.tokens.contains_key(&0x005A), + "EffectivePartitionKey token must be emitted for full HPK partition key" + ); + } + + /// `compute_range` error cases (e.g., more PK components supplied than the + /// container's definition declares) must surface as a wrap error, mapped + /// to `BadRequest` upstream — never silently emit broken EPK metadata. 
+ #[test] + fn wrap_rejects_partition_key_with_too_many_components() { + let request = signed_request(None); + let auth_context = AuthorizationContext::new( + Method::Get, + ResourceType::Document, + "dbs/db1/colls/coll1/docs/doc1", + ); + let partition_key = PartitionKey::from(vec![ + PartitionKeyValue::from("tenant1".to_string()), + PartitionKeyValue::from("extra".to_string()), + ]); + let partition_key_definition = PartitionKeyDefinition::from("/tenantId"); + + let error = wrap_request_for_gateway20( + &request, + &wrap_inputs( + &auth_context, + OperationType::Read, + Some(&partition_key), + Some(&partition_key_definition), + ), + ) + .unwrap_err(); + + assert_eq!(error.kind(), &ErrorKind::DataConversion); + } + + #[test] + fn wrap_propagates_continuation_token_into_rntbd_metadata() { + let mut request = signed_request(None); + request.headers.insert(X_MS_CONTINUATION, "page-token-1"); + let auth_context = + AuthorizationContext::new(Method::Get, ResourceType::Document, "dbs/db1/colls/coll1"); + + let wrapped = wrap_request_for_gateway20( + &request, + &WrapInputs { + auth_context: &auth_context, + operation_type: OperationType::Query, + resource_type: ResourceType::Document, + partition_key: None, + partition_key_definition: None, + effective_consistency: DefaultConsistencyLevel::Session, + account_name: Some("account"), + }, + ) + .unwrap(); + let parsed = parse_wrapped_request(&wrapped, 10); + + assert_eq!( + parsed.tokens[&0x0006], + ParsedTokenValue::String("page-token-1".into()), + "continuation token should be encoded as string token 0x0006", + ); + assert!( + wrapped + .headers + .get_optional_str(&X_MS_CONTINUATION) + .is_none(), + "x-ms-continuation header should not be forwarded on the outer HTTP request", + ); + } + + #[test] + fn wrap_omits_continuation_token_when_header_absent() { + let request = signed_request(None); + let auth_context = + AuthorizationContext::new(Method::Get, ResourceType::Document, "dbs/db1/colls/coll1"); + + let wrapped = 
wrap_request_for_gateway20( + &request, + &WrapInputs { + auth_context: &auth_context, + operation_type: OperationType::Query, + resource_type: ResourceType::Document, + partition_key: None, + partition_key_definition: None, + effective_consistency: DefaultConsistencyLevel::Session, + account_name: Some("account"), + }, + ) + .unwrap(); + let parsed = parse_wrapped_request(&wrapped, 9); + + assert!( + !parsed.tokens.contains_key(&0x0006), + "continuation token should be absent when no x-ms-continuation header is present", + ); + } + + #[test] + fn wrap_emits_empty_continuation_token_when_header_value_empty() { + // Symmetry with .NET (`ThinClientStoreClient.PrepareRequestForProxyAsync`), + // Java (`RntbdRequestHeader.ContinuationToken` is *not* in + // `thinClientProxyExcludedSet`), and the unwrap side which forwards + // empty continuation strings verbatim. Continuation is opaque on the + // wire — the wrap path does not infer intent from emptiness. + let mut request = signed_request(None); + request.headers.insert(X_MS_CONTINUATION, ""); + let auth_context = + AuthorizationContext::new(Method::Get, ResourceType::Document, "dbs/db1/colls/coll1"); + + let wrapped = wrap_request_for_gateway20( + &request, + &WrapInputs { + auth_context: &auth_context, + operation_type: OperationType::Query, + resource_type: ResourceType::Document, + partition_key: None, + partition_key_definition: None, + effective_consistency: DefaultConsistencyLevel::Session, + account_name: Some("account"), + }, + ) + .unwrap(); + let parsed = parse_wrapped_request(&wrapped, 10); + + assert_eq!( + parsed.tokens[&0x0006], + ParsedTokenValue::String(String::new()), + "empty continuation header should be emitted as a zero-length string token", + ); + } + + #[test] + fn wrap_only_keeps_user_agent_and_activity_id_headers() { + let request = signed_request(None); + let auth_context = AuthorizationContext::new( + Method::Get, + ResourceType::Document, + "dbs/db1/colls/coll1/docs/doc1", + ); + + let 
wrapped = wrap_request_for_gateway20( + &request, + &wrap_inputs(&auth_context, OperationType::Read, None, None), + ) + .unwrap(); + + assert_eq!( + wrapped.headers.get_optional_str(&USER_AGENT), + Some("test-agent") + ); + assert_eq!( + wrapped.headers.get_optional_str(&X_MS_ACTIVITY_ID), + Some(ACTIVITY_ID) + ); + assert!(wrapped.headers.get_optional_str(&AUTHORIZATION).is_none()); + assert!(wrapped.headers.get_optional_str(&X_MS_DATE).is_none()); + assert!(wrapped.headers.get_optional_str(&CONTENT_TYPE).is_none()); + assert!(wrapped.headers.get_optional_str(&ACCEPT).is_none()); + } + + #[test] + fn wrap_rejects_missing_authorization_header() { + let mut request = signed_request(None); + request.headers.remove(AUTHORIZATION); + let auth_context = AuthorizationContext::new( + Method::Get, + ResourceType::Document, + "dbs/db1/colls/coll1/docs/doc1", + ); + + let error = wrap_request_for_gateway20( + &request, + &wrap_inputs(&auth_context, OperationType::Read, None, None), + ) + .unwrap_err(); + + assert_eq!(error.kind(), &ErrorKind::DataConversion); + } + + #[test] + fn wrap_rejects_missing_date_header() { + let mut request = signed_request(None); + request.headers.remove(X_MS_DATE); + let auth_context = AuthorizationContext::new( + Method::Get, + ResourceType::Document, + "dbs/db1/colls/coll1/docs/doc1", + ); + + let error = wrap_request_for_gateway20( + &request, + &wrap_inputs(&auth_context, OperationType::Read, None, None), + ) + .unwrap_err(); + + assert_eq!(error.kind(), &ErrorKind::DataConversion); + } + + #[test] + fn wrap_rejects_invalid_activity_id() { + let mut request = signed_request(None); + request.headers.insert(X_MS_ACTIVITY_ID, "not-a-guid"); + let auth_context = AuthorizationContext::new( + Method::Get, + ResourceType::Document, + "dbs/db1/colls/coll1/docs/doc1", + ); + + let error = wrap_request_for_gateway20( + &request, + &wrap_inputs(&auth_context, OperationType::Read, None, None), + ) + .unwrap_err(); + + assert_eq!(error.kind(), 
&ErrorKind::DataConversion); + } + + #[test] + fn unwrap_maps_response_status_headers_and_body() { + let activity_id = Uuid::parse_str(ACTIVITY_ID).unwrap(); + let response = HttpResponse { + status: 200, + headers: Headers::new(), + body: response_frame( + 404, + activity_id, + |tokens| { + write_u32_token(tokens, 0x001C, 1002); + write_double_token(tokens, 0x0015, 3.5); + write_string_token(tokens, 0x003E, "1:2#3"); + write_string_token(tokens, 0x0004, "\"etag\""); + write_string_token(tokens, 0x0003, "continuation"); + write_i64_token(tokens, 0x0013, 42); + write_i64_token(tokens, 0x0032, 43); + write_i64_token(tokens, 0x0029, 44); + write_string_token(tokens, 0x0017, "dbs/db1/colls/coll1/docs/doc1"); + }, + b"{}", + ), + }; + + let unwrapped = unwrap_response_for_gateway20(response).unwrap(); + + assert_eq!(unwrapped.status, 404); + assert_eq!(unwrapped.body, b"{}".to_vec()); + assert_eq!( + unwrapped.headers.get_optional_str(&X_MS_ACTIVITY_ID), + Some(ACTIVITY_ID) + ); + assert_eq!( + unwrapped + .headers + .get_optional_str(&HeaderName::from_static("x-ms-substatus")), + Some("1002") + ); + assert_eq!( + unwrapped + .headers + .get_optional_str(&HeaderName::from_static("x-ms-request-charge")), + Some("3.5") + ); + assert_eq!( + unwrapped + .headers + .get_optional_str(&HeaderName::from_static("x-ms-session-token")), + Some("1:2#3") + ); + assert_eq!( + unwrapped + .headers + .get_optional_str(&HeaderName::from_static("etag")), + Some("\"etag\"") + ); + assert_eq!( + unwrapped + .headers + .get_optional_str(&HeaderName::from_static("x-ms-continuation")), + Some("continuation") + ); + assert_eq!( + unwrapped + .headers + .get_optional_str(&HeaderName::from_static("lsn")), + Some("42") + ); + assert_eq!(unwrapped.headers.get_optional_str(&X_MS_LSN), Some("42")); + assert_eq!( + unwrapped + .headers + .get_optional_str(&HeaderName::from_static("x-ms-item-lsn")), + Some("43") + ); + assert_eq!( + unwrapped + .headers + .get_optional_str(&X_MS_GLOBAL_COMMITTED_LSN), 
+ Some("44") + ); + } + + #[test] + fn unwrap_preserves_retry_after_for_throttle() { + let response = HttpResponse { + status: 200, + headers: Headers::new(), + body: response_frame( + 429, + Uuid::parse_str(ACTIVITY_ID).unwrap(), + |tokens| write_u32_token(tokens, 0x000C, 125), + b"", + ), + }; + + let unwrapped = unwrap_response_for_gateway20(response).unwrap(); + + assert_eq!(unwrapped.status, 429); + assert_eq!( + unwrapped + .headers + .get_optional_str(&HeaderName::from_static("x-ms-retry-after-ms")), + Some("125") + ); + } + + #[test] + fn unwrap_rejects_malformed_rntbd_body() { + let response = HttpResponse { + status: 200, + headers: Headers::new(), + body: vec![1, 2, 3], + }; + + let error = unwrap_response_for_gateway20(response).unwrap_err(); + + assert_eq!(error.kind(), &ErrorKind::DataConversion); + } + + #[test] + fn unwrap_rejects_out_of_range_inner_status() { + let response = HttpResponse { + status: 200, + headers: Headers::new(), + body: response_frame(70_000, Uuid::parse_str(ACTIVITY_ID).unwrap(), |_| {}, b""), + }; + + let error = unwrap_response_for_gateway20(response).unwrap_err(); + + assert_eq!(error.kind(), &ErrorKind::DataConversion); + } + + fn parsed_transport_id(parsed: &ParsedRequest) -> u32 { + match parsed.tokens[&0x004D] { + ParsedTokenValue::ULong(value) => value, + _ => unreachable!(), + } + } + + fn response_frame( + status: u32, + activity_id: Uuid, + write_tokens: impl FnOnce(&mut Vec<u8>), + body: &[u8], + ) -> Vec<u8> { + let mut bytes = Vec::new(); + bytes.extend_from_slice(&0_u32.to_le_bytes()); + bytes.extend_from_slice(&status.to_le_bytes()); + write_uuid(&mut bytes, activity_id); + write_tokens(&mut bytes); + bytes.extend_from_slice(body); + let total_len = u32::try_from(bytes.len()).unwrap(); + bytes[0..4].copy_from_slice(&total_len.to_le_bytes()); + bytes + } + + fn write_string_token(bytes: &mut Vec<u8>, id: u16, value: &str) { + bytes.extend_from_slice(&id.to_le_bytes()); + bytes.push(0x08); + 
bytes.extend_from_slice(&(value.len() as u16).to_le_bytes()); + bytes.extend_from_slice(value.as_bytes()); + } + + fn write_u32_token(bytes: &mut Vec<u8>, id: u16, value: u32) { + bytes.extend_from_slice(&id.to_le_bytes()); + bytes.push(0x02); + bytes.extend_from_slice(&value.to_le_bytes()); + } + + fn write_i64_token(bytes: &mut Vec<u8>, id: u16, value: i64) { + bytes.extend_from_slice(&id.to_le_bytes()); + bytes.push(0x05); + bytes.extend_from_slice(&value.to_le_bytes()); + } + + fn write_double_token(bytes: &mut Vec<u8>, id: u16, value: f64) { + bytes.extend_from_slice(&id.to_le_bytes()); + bytes.push(0x0E); + bytes.extend_from_slice(&value.to_le_bytes()); + } + + fn write_uuid(bytes: &mut Vec<u8>, value: Uuid) { + let value = value.as_u128(); + let msb = (value >> 64) as u64; + let lsb = value as u64; + bytes.extend_from_slice(&msb.to_le_bytes()); + bytes.extend_from_slice(&lsb.to_le_bytes()); + } + + fn take_u8(src: &mut &[u8]) -> u8 { + let value = src[0]; + *src = &src[1..]; + value + } + + fn take_u16(src: &mut &[u8]) -> u16 { + u16::from_le_bytes(take_array(src)) + } + + fn take_u32(src: &mut &[u8]) -> u32 { + u32::from_le_bytes(take_array(src)) + } + + fn take_i64(src: &mut &[u8]) -> i64 { + i64::from_le_bytes(take_array(src)) + } + + fn take_uuid(src: &mut &[u8]) -> Uuid { + let msb = u64::from_le_bytes(take_array(src)); + let lsb = u64::from_le_bytes(take_array(src)); + Uuid::from_u128(((msb as u128) << 64) | lsb as u128) + } + + fn take_string(src: &mut &[u8], len: usize) -> String { + String::from_utf8(take_bytes(src, len).to_vec()).unwrap() + } + + fn take_bytes<'a>(src: &mut &'a [u8], len: usize) -> &'a [u8] { + let (head, tail) = src.split_at(len); + *src = tail; + head + } + + fn take_array<const N: usize>(src: &mut &[u8]) -> [u8; N] { + let bytes = take_bytes(src, N); + let mut out = [0; N]; + out.copy_from_slice(bytes); + out + } +} diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/gateway20_eligibility.rs 
b/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/gateway20_eligibility.rs new file mode 100644 index 00000000000..508dc79d6fa --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/gateway20_eligibility.rs @@ -0,0 +1,141 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//! Gateway 2.0 operation eligibility filter. + +use crate::models::{OperationType, ResourceType}; + +/// Returns `true` when the resource and operation pair is eligible for Gateway 2.0. +/// +/// Only `ResourceType::Document` is currently eligible, matching Java's +/// `ThinClientStoreModel`. Stored-procedure execution is explicitly out of +/// scope for Rust SDK GA; every non-Document resource type falls back to +/// standard Gateway via the eligibility-fallback path. +/// +/// `OperationType::Patch` is not currently a variant on the Rust enum and is +/// therefore not handled here. When the variant is added in a future slice, +/// this match must be updated. +// Slice 3 wires this helper into routing. +#[allow(dead_code)] +pub(crate) fn is_operation_supported_by_gateway20( + resource_type: ResourceType, + operation_type: OperationType, +) -> bool { + // Both arms of this match are intentionally exhaustive (no wildcard `_` arm) so + // that adding a new variant to either enum is a compile-time error, forcing an + // explicit eligibility decision rather than a silent fail-closed default. 
+ match resource_type { + ResourceType::Document => match operation_type { + OperationType::Create + | OperationType::Read + | OperationType::Replace + | OperationType::Upsert + | OperationType::Delete + | OperationType::Query + | OperationType::SqlQuery + | OperationType::QueryPlan + | OperationType::ReadFeed + | OperationType::Batch => true, + OperationType::Head | OperationType::HeadFeed | OperationType::Execute => false, + }, + ResourceType::DatabaseAccount + | ResourceType::Database + | ResourceType::DocumentCollection + | ResourceType::StoredProcedure + | ResourceType::Trigger + | ResourceType::UserDefinedFunction + | ResourceType::PartitionKeyRange + | ResourceType::Offer => false, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn all_resource_types() -> [ResourceType; 9] { + [ + ResourceType::DatabaseAccount, + ResourceType::Database, + ResourceType::DocumentCollection, + ResourceType::Document, + ResourceType::StoredProcedure, + ResourceType::Trigger, + ResourceType::UserDefinedFunction, + ResourceType::PartitionKeyRange, + ResourceType::Offer, + ] + } + + fn all_operation_types() -> [OperationType; 13] { + [ + OperationType::Create, + OperationType::Read, + OperationType::ReadFeed, + OperationType::Replace, + OperationType::Delete, + OperationType::Upsert, + OperationType::Query, + OperationType::SqlQuery, + OperationType::QueryPlan, + OperationType::Batch, + OperationType::Head, + OperationType::HeadFeed, + OperationType::Execute, + ] + } + + fn expected_gateway20_eligibility( + resource_type: ResourceType, + operation_type: OperationType, + ) -> bool { + match resource_type { + ResourceType::Document => match operation_type { + OperationType::Create + | OperationType::Read + | OperationType::Replace + | OperationType::Upsert + | OperationType::Delete + | OperationType::Query + | OperationType::SqlQuery + | OperationType::QueryPlan + | OperationType::ReadFeed + | OperationType::Batch => true, + OperationType::Head | OperationType::HeadFeed | 
OperationType::Execute => false, + }, + ResourceType::DatabaseAccount + | ResourceType::Database + | ResourceType::DocumentCollection + | ResourceType::StoredProcedure + | ResourceType::Trigger + | ResourceType::UserDefinedFunction + | ResourceType::PartitionKeyRange + | ResourceType::Offer => false, + } + } + + #[test] + fn gateway20_eligibility_matrix_is_exhaustive() { + for resource_type in all_resource_types() { + for operation_type in all_operation_types() { + assert_eq!( + is_operation_supported_by_gateway20(resource_type, operation_type), + expected_gateway20_eligibility(resource_type, operation_type), + "unexpected Gateway 2.0 eligibility for {resource_type:?} {operation_type:?}" + ); + } + } + } + + #[test] + fn stored_procedure_execution_is_explicitly_ineligible() { + assert!(!is_operation_supported_by_gateway20( + ResourceType::StoredProcedure, + OperationType::Execute + )); + assert!(!is_operation_supported_by_gateway20( + ResourceType::Document, + OperationType::Execute + )); + } +} diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/http_client_factory.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/http_client_factory.rs index 40f893dc481..30b364a1b23 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/http_client_factory.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/http_client_factory.rs @@ -7,7 +7,7 @@ use std::{fmt, sync::Arc}; use super::cosmos_transport_client::TransportClient; -use crate::diagnostics::TransportHttpVersion; +use crate::diagnostics::{TransportHttpVersion, TransportKind}; use crate::options::ConnectionPoolOptions; /// HTTP protocol policy required by a transport. @@ -26,6 +26,13 @@ pub struct HttpClientConfig { pub(crate) request_timeout: std::time::Duration, pub(crate) allow_invalid_cert: bool, pub(crate) http2_keep_alive_while_idle: bool, + /// The transport kind this HTTP client serves, when it is bound to a + /// dataplane transport. 
Metadata clients (account discovery, etc.) leave + /// this `None` because they are not gateway/Gateway-2.0-specific. + /// + /// This is consumed by the fault-injection layer so rules can scope + /// themselves to a specific transport (`with_transport_kind`). + pub(crate) transport_kind: Option<TransportKind>, } impl HttpClientConfig { @@ -42,6 +49,7 @@ impl HttpClientConfig { request_timeout: connection_pool.max_metadata_request_timeout(), allow_invalid_cert: false, http2_keep_alive_while_idle: negotiated_version.is_http2(), + transport_kind: None, } } @@ -58,6 +66,7 @@ impl HttpClientConfig { request_timeout: connection_pool.max_dataplane_request_timeout(), allow_invalid_cert: false, http2_keep_alive_while_idle: negotiated_version.is_http2(), + transport_kind: Some(TransportKind::Gateway), } } @@ -68,6 +77,7 @@ impl HttpClientConfig { request_timeout: connection_pool.max_dataplane_request_timeout(), allow_invalid_cert: false, http2_keep_alive_while_idle: true, + transport_kind: Some(TransportKind::Gateway20), } } diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/mod.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/mod.rs index 790ae170707..1b25c3b9bb8 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/mod.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/mod.rs @@ -20,10 +20,15 @@ pub(crate) mod background_task_manager; pub(crate) mod cosmos_headers; pub(crate) mod cosmos_transport_client; mod emulator; +mod gateway20_dispatch; +/// Gateway 2.0 operation eligibility filter. 
+pub(crate) mod gateway20_eligibility; +pub(crate) use gateway20_eligibility::is_operation_supported_by_gateway20; pub(crate) mod http_client_factory; pub(crate) mod request_signing; #[cfg(feature = "reqwest")] pub(crate) mod reqwest_transport_client; +pub(crate) mod rntbd; mod sharded_transport; pub(crate) use sharded_transport::EndpointKey; mod tracked_transport; @@ -48,6 +53,9 @@ use self::http_client_factory::DefaultHttpClientFactory; pub(crate) use authorization_policy::generate_authorization; pub(crate) use authorization_policy::AuthorizationContext; pub(crate) use emulator::is_emulator_host; +pub(crate) use gateway20_dispatch::{ + unwrap_response_for_gateway20, wrap_request_for_gateway20, WrapInputs, +}; pub(crate) use tracked_transport::infer_request_sent_status; /// Cosmos DB REST API version. @@ -283,7 +291,7 @@ impl CosmosTransport { } match transport_mode { - TransportMode::Gateway20 if self.connection_pool.is_gateway20_allowed() => { + TransportMode::Gateway20 if !self.connection_pool.gateway20_disabled() => { let transport = match self.dataplane_gateway20_transport.get() { Some(t) => t.clone(), None => { @@ -390,7 +398,7 @@ pub(crate) mod tests { #[test] fn dataplane_transport_uses_gateway20_when_selected() { let pool = ConnectionPoolOptionsBuilder::new() - .with_is_gateway20_allowed(true) + .with_gateway20_disabled(false) .build() .unwrap(); let transport = CosmosTransport::for_tests(pool, TransportHttpVersion::Http2).unwrap(); @@ -406,7 +414,7 @@ pub(crate) mod tests { #[test] fn dataplane_transport_falls_back_to_sharded_gateway_when_endpoint_is_standard() { let pool = ConnectionPoolOptionsBuilder::new() - .with_is_gateway20_allowed(true) + .with_gateway20_disabled(false) .build() .unwrap(); let transport = CosmosTransport::for_tests(pool, TransportHttpVersion::Http2).unwrap(); @@ -422,7 +430,7 @@ pub(crate) mod tests { #[test] fn dataplane_transport_ignores_gateway20_when_gateway20_disabled() { let pool = ConnectionPoolOptionsBuilder::new() - 
.with_is_gateway20_allowed(false) + .with_gateway20_disabled(true) .build() .unwrap(); let transport = CosmosTransport::for_tests(pool, TransportHttpVersion::Http2).unwrap(); diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/rntbd/mod.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/rntbd/mod.rs new file mode 100644 index 00000000000..b0c2571448c --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/rntbd/mod.rs @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//! Gateway 2.0 RNTBD wire-format support. +//! +//! This module owns in-memory request serialization and response deserialization +//! for RNTBD frames carried by the Gateway 2.0 transport path. + +// Slice 1 intentionally lands the wire-format module before later slices wire it +// into the transport pipeline. +#![allow(dead_code, unused_imports)] + +pub(crate) mod request; +pub(crate) mod response; +pub(crate) mod status; +pub(crate) mod tokens; + +pub(crate) use request::RntbdRequestFrame; +pub(crate) use response::RntbdResponse; +pub(crate) use status::map_rntbd_status_to_cosmos_status; +pub(crate) use tokens::{Token, TokenType, TokenValue}; diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/rntbd/request.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/rntbd/request.rs new file mode 100644 index 00000000000..dffb829ced6 --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/rntbd/request.rs @@ -0,0 +1,224 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//! RNTBD request frame serialization. + +use uuid::Uuid; + +use crate::models::{OperationType, ResourceType}; + +use super::tokens::{ + data_conversion_error, write_uuid_le, RntbdOperationType, RntbdResourceType, Token, +}; + +/// A Gateway 2.0 RNTBD request frame. +/// +/// The body is schema-agnostic raw bytes. 
When [`body`](Self::body) is present, +/// serialization emits the payload length followed by the payload bytes. +#[derive(Clone, Debug, PartialEq)] +pub(crate) struct RntbdRequestFrame { + /// Resource type encoded into the frame header. + pub(crate) resource_type: ResourceType, + /// Operation type encoded into the frame header. + pub(crate) operation_type: OperationType, + /// Activity identifier encoded as two little-endian `u64` values. + pub(crate) activity_id: Uuid, + /// Metadata token stream. + pub(crate) metadata: Vec<Token>, + /// Optional raw request payload. + pub(crate) body: Option<Vec<u8>>, +} + +impl RntbdRequestFrame { + /// Serializes the request frame to Gateway 2.0 RNTBD bytes. + /// + /// The total length field is inclusive of its own four bytes. Returns + /// [`ErrorKind::DataConversion`] when an input exceeds an RNTBD wire + /// length limit (e.g., a metadata token value longer than the + /// `SmallString` length prefix supports, a body larger than `u32::MAX`, + /// or a frame whose total length exceeds `u32::MAX`). 
+ /// + /// [`ErrorKind::DataConversion`]: azure_core::error::ErrorKind::DataConversion + pub(crate) fn serialize(&self) -> azure_core::Result<Vec<u8>> { + let metadata_len: usize = self.metadata.iter().map(Token::encoded_len).sum(); + let body_len = self.body.as_ref().map_or(0, |body| 4 + body.len()); + let total_len = 24 + metadata_len + body_len; + let total_len_u32 = u32::try_from(total_len).map_err(|_| { + data_conversion_error(format!( + "RNTBD request frame length {total_len} exceeds u32::MAX" + )) + })?; + + let mut out = Vec::with_capacity(total_len); + out.extend_from_slice(&total_len_u32.to_le_bytes()); + out.extend_from_slice( + &RntbdResourceType::from(self.resource_type) + .value() + .to_le_bytes(), + ); + out.extend_from_slice( + &RntbdOperationType::from(self.operation_type) + .value() + .to_le_bytes(), + ); + write_uuid_le(&mut out, self.activity_id); + + for token in &self.metadata { + token.write_to(&mut out)?; + } + + if let Some(body) = &self.body { + let body_len = u32::try_from(body.len()).map_err(|_| { + data_conversion_error(format!( + "RNTBD request payload length {} exceeds u32::MAX", + body.len() + )) + })?; + out.extend_from_slice(&body_len.to_le_bytes()); + out.extend_from_slice(body); + } + + debug_assert_eq!(out.len(), total_len); + Ok(out) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::driver::transport::rntbd::tokens::{ + data_conversion_error, read_u16_le, read_u32_le, read_uuid_le, RntbdOperationType, + RntbdResourceType, TokenValue, + }; + + #[test] + fn request_frames_round_trip_for_slice_one_operations() { + let operations = [ + OperationType::Create, + OperationType::Read, + OperationType::ReadFeed, + OperationType::Replace, + OperationType::Delete, + OperationType::Upsert, + OperationType::Query, + OperationType::SqlQuery, + OperationType::Head, + OperationType::HeadFeed, + OperationType::Batch, + ]; + + for operation_type in operations { + for body in [None, Some(vec![0x7b, 0x7d])] { + let frame = 
RntbdRequestFrame { + resource_type: ResourceType::Document, + operation_type, + activity_id: Uuid::from_u128(0x1234_5678_90ab_cdef_0123_4567_89ab_cdef), + metadata: Vec::new(), + body, + }; + + let bytes = frame.serialize().unwrap(); + let parsed = parse_request_for_tests(&bytes, frame.body.is_some()).unwrap(); + + assert_eq!(parsed, frame); + } + } + } + + #[test] + fn query_plan_uses_sql_query_wire_id_until_metadata_rules_land() { + let frame = RntbdRequestFrame { + resource_type: ResourceType::Document, + operation_type: OperationType::QueryPlan, + activity_id: Uuid::nil(), + metadata: Vec::new(), + body: None, + }; + + let bytes = frame.serialize().unwrap(); + let operation_id = u16::from_le_bytes([bytes[6], bytes[7]]); + + // QueryPlan has no distinct Java RNTBD operation ID. Slice 2 will add + // the metadata that disambiguates query-plan requests from SqlQuery. + assert_eq!( + operation_id, + RntbdOperationType::from(OperationType::SqlQuery).value() + ); + } + + #[test] + fn metadata_tokens_are_serialized_between_header_and_body() { + let frame = RntbdRequestFrame { + resource_type: ResourceType::Document, + operation_type: OperationType::Read, + activity_id: Uuid::nil(), + metadata: vec![Token::new(0x00CE, TokenValue::String("account".to_owned()))], + body: None, + }; + + let bytes = frame.serialize().unwrap(); + let parsed = parse_request_for_tests(&bytes, false).unwrap(); + + assert_eq!(parsed, frame); + } + + #[test] + fn serialize_returns_error_when_small_string_exceeds_u8_length_prefix() { + let oversized = "a".repeat(256); + let frame = RntbdRequestFrame { + resource_type: ResourceType::Document, + operation_type: OperationType::Read, + activity_id: Uuid::nil(), + metadata: vec![Token::new(0x0001, TokenValue::SmallString(oversized))], + body: None, + }; + + let err = frame.serialize().unwrap_err(); + assert_eq!(*err.kind(), azure_core::error::ErrorKind::DataConversion); + } + + fn parse_request_for_tests( + bytes: &[u8], + has_body: bool, + ) -> 
azure_core::Result<RntbdRequestFrame> { + let mut src = bytes; + let total_len = read_u32_le(&mut src)? as usize; + if total_len != bytes.len() { + return Err(data_conversion_error(format!( + "request frame length {total_len} did not match buffer length {}", + bytes.len() + ))); + } + + let resource_type = + ResourceType::try_from(RntbdResourceType::try_from(read_u16_le(&mut src)?)?)?; + let operation_type = + OperationType::try_from(RntbdOperationType::try_from(read_u16_le(&mut src)?)?)?; + let activity_id = read_uuid_le(&mut src)?; + + let mut metadata = Vec::new(); + let body = if has_body { + let payload_len = read_u32_le(&mut src)? as usize; + if src.len() != payload_len { + return Err(data_conversion_error(format!( + "request payload length {payload_len} did not match remaining bytes {}", + src.len() + ))); + } + Some(src.to_vec()) + } else { + while !src.is_empty() { + metadata.push(Token::read_from(&mut src)?); + } + None + }; + + Ok(RntbdRequestFrame { + resource_type, + operation_type, + activity_id, + metadata, + body, + }) + } +} diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/rntbd/response.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/rntbd/response.rs new file mode 100644 index 00000000000..db2cacf3906 --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/rntbd/response.rs @@ -0,0 +1,266 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//! RNTBD response frame deserialization. + +use uuid::Uuid; + +use crate::models::CosmosStatus; + +use super::{ + status::map_rntbd_status_to_cosmos_status, + tokens::{ + data_conversion_error, read_u32_le, read_uuid_le, RntbdResponseToken, Token, TokenValue, + }, +}; + +/// A decoded Gateway 2.0 RNTBD response frame. +/// +/// The body is schema-agnostic raw bytes. Recognized metadata tokens are surfaced +/// as typed optional fields; unknown token IDs are silently consumed. 
+#[derive(Clone, Debug, PartialEq)] +pub(crate) struct RntbdResponse { + /// Status composed from the frame HTTP status and optional SubStatus token. + pub(crate) status: CosmosStatus, + /// Activity identifier echoed by the service. + pub(crate) activity_id: Uuid, + /// Raw response payload bytes. + pub(crate) body: Vec<u8>, + /// Continuation token for feed-style operations. + pub(crate) continuation_token: Option<String>, + /// Entity tag returned by the service. + pub(crate) etag: Option<String>, + /// Retry-after delay in milliseconds. + pub(crate) retry_after_ms: Option<u32>, + /// Logical sequence number. + pub(crate) lsn: Option<i64>, + /// Request charge in request units. + pub(crate) request_charge: Option<f64>, + /// Owner full name metadata. + pub(crate) owner_full_name: Option<String>, + /// Partition key range identifier. + pub(crate) partition_key_range_id: Option<String>, + /// Item logical sequence number. + pub(crate) item_lsn: Option<i64>, + /// Global committed logical sequence number. + pub(crate) global_committed_lsn: Option<i64>, + /// Transport request identifier. + pub(crate) transport_request_id: Option<u32>, + /// Session token for session consistency. + pub(crate) session_token: Option<String>, +} + +impl RntbdResponse { + /// Deserializes a Gateway 2.0 RNTBD response frame. + /// + /// Unknown metadata token IDs are silently consumed when their token type is + /// known. Malformed token values and unknown token type bytes return errors. + /// The Slice 1 frame shape has no separate metadata length, so the parser + /// advances the tracking offset by decoding complete metadata tokens and + /// preserves any trailing bytes shorter than a token header as body bytes. + pub(crate) fn deserialize(bytes: &[u8]) -> azure_core::Result<Self> { + let mut src = bytes; + let total_len = read_u32_le(&mut src)? 
as usize; + if total_len > bytes.len() { + return Err(data_conversion_error(format!( + "RNTBD response length {total_len} exceeds buffer length {}", + bytes.len() + ))); + } + if total_len < 24 { + return Err(data_conversion_error(format!( + "RNTBD response length {total_len} is smaller than the 24-byte header" + ))); + } + + let mut frame = &bytes[4..total_len]; + let http_status = read_u32_le(&mut frame)?; + let activity_id = read_uuid_le(&mut frame)?; + + let mut continuation_token = None; + let mut etag = None; + let mut retry_after_ms = None; + let mut lsn = None; + let mut request_charge = None; + let mut owner_full_name = None; + let mut sub_status = None; + let mut partition_key_range_id = None; + let mut item_lsn = None; + let mut global_committed_lsn = None; + let mut transport_request_id = None; + let mut session_token = None; + + while frame.len() >= 3 { + let token = Token::read_from(&mut frame)?; + match RntbdResponseToken::try_from(token.id) { + Ok(RntbdResponseToken::ContinuationToken) => { + continuation_token = Some(expect_string(token, "ContinuationToken")?); + } + Ok(RntbdResponseToken::ETag) => { + etag = Some(expect_string(token, "ETag")?); + } + Ok(RntbdResponseToken::RetryAfterMilliseconds) => { + retry_after_ms = Some(expect_u32(token, "RetryAfterMilliseconds")?); + } + Ok(RntbdResponseToken::Lsn) => { + lsn = Some(expect_i64(token, "LSN")?); + } + Ok(RntbdResponseToken::RequestCharge) => { + request_charge = Some(expect_f64(token, "RequestCharge")?); + } + Ok(RntbdResponseToken::OwnerFullName) => { + owner_full_name = Some(expect_string(token, "OwnerFullName")?); + } + Ok(RntbdResponseToken::SubStatus) => { + sub_status = Some(expect_u32(token, "SubStatus")?); + } + Ok(RntbdResponseToken::PartitionKeyRangeId) => { + partition_key_range_id = Some(expect_string(token, "PartitionKeyRangeId")?); + } + Ok(RntbdResponseToken::ItemLsn) => { + item_lsn = Some(expect_i64(token, "ItemLSN")?); + } + Ok(RntbdResponseToken::GlobalCommittedLsn) => { + 
global_committed_lsn = Some(expect_i64(token, "GlobalCommittedLSN")?); + } + Ok(RntbdResponseToken::TransportRequestId) => { + transport_request_id = Some(expect_u32(token, "TransportRequestID")?); + } + Ok(RntbdResponseToken::SessionToken) => { + session_token = Some(expect_string(token, "SessionToken")?); + } + Err(()) => {} + } + } + + Ok(Self { + status: map_rntbd_status_to_cosmos_status(http_status, sub_status), + activity_id, + body: frame.to_vec(), + continuation_token, + etag, + retry_after_ms, + lsn, + request_charge, + owner_full_name, + partition_key_range_id, + item_lsn, + global_committed_lsn, + transport_request_id, + session_token, + }) + } +} + +fn expect_string(token: Token, name: &str) -> azure_core::Result<String> { + match token.value { + TokenValue::String(value) => Ok(value), + _ => Err(unexpected_token_type(name)), + } +} + +fn expect_u32(token: Token, name: &str) -> azure_core::Result<u32> { + match token.value { + TokenValue::ULong(value) => Ok(value), + _ => Err(unexpected_token_type(name)), + } +} + +fn expect_i64(token: Token, name: &str) -> azure_core::Result<i64> { + match token.value { + TokenValue::LongLong(value) => Ok(value), + _ => Err(unexpected_token_type(name)), + } +} + +fn expect_f64(token: Token, name: &str) -> azure_core::Result<f64> { + match token.value { + TokenValue::Double(value) => Ok(value), + _ => Err(unexpected_token_type(name)), + } +} + +fn unexpected_token_type(name: &str) -> azure_core::Error { + data_conversion_error(format!("RNTBD token {name} had an unexpected value type")) +} + +#[cfg(test)] +mod tests { + use super::*; + use azure_core::http::StatusCode; + + use crate::driver::transport::rntbd::tokens::write_uuid_le; + + #[test] + fn unknown_token_id_is_silently_skipped() { + let mut frame = response_header(StatusCode::Ok); + Token::new(0x0015, TokenValue::Double(1.5)) + .write_to(&mut frame) + .unwrap(); + Token::new(0xFFFE, TokenValue::SmallString("hello".to_owned())) + .write_to(&mut frame) + .unwrap(); + Token::new(0x001C, 
TokenValue::ULong(1002)) + .write_to(&mut frame) + .unwrap(); + patch_total_len(&mut frame); + + let response = RntbdResponse::deserialize(&frame).unwrap(); + + assert_eq!(response.status.status_code(), StatusCode::Ok); + assert_eq!(response.status.sub_status().unwrap().value(), 1002); + assert_eq!(response.request_charge, Some(1.5)); + assert!(response.body.is_empty()); + } + + #[test] + fn total_length_past_buffer_is_rejected() { + let mut frame = response_header(StatusCode::Ok); + let total_len = (frame.len() as u32) + 1; + frame[0..4].copy_from_slice(&total_len.to_le_bytes()); + + let err = RntbdResponse::deserialize(&frame).unwrap_err(); + + assert_eq!(*err.kind(), azure_core::error::ErrorKind::DataConversion); + } + + #[test] + fn trailing_bytes_shorter_than_token_header_are_body() { + let mut frame = response_header(StatusCode::Ok); + frame.extend_from_slice(&[0xAA, 0xBB]); + patch_total_len(&mut frame); + + let response = RntbdResponse::deserialize(&frame).unwrap(); + + assert_eq!(response.body, vec![0xAA, 0xBB]); + } + + #[test] + fn metadata_before_short_body_is_preserved() { + let mut frame = response_header(StatusCode::Ok); + Token::new(0x0015, TokenValue::Double(2.5)) + .write_to(&mut frame) + .unwrap(); + frame.extend_from_slice(&[0xAA, 0xBB]); + patch_total_len(&mut frame); + + let response = RntbdResponse::deserialize(&frame).unwrap(); + + assert_eq!(response.request_charge, Some(2.5)); + assert_eq!(response.body, vec![0xAA, 0xBB]); + } + + fn response_header(status_code: StatusCode) -> Vec<u8> { + let mut frame = Vec::new(); + frame.extend_from_slice(&0_u32.to_le_bytes()); + frame.extend_from_slice(&u16::from(status_code).to_le_bytes()); + frame.extend_from_slice(&0_u16.to_le_bytes()); + write_uuid_le(&mut frame, Uuid::nil()); + frame + } + + fn patch_total_len(frame: &mut [u8]) { + let total_len = u32::try_from(frame.len()).unwrap(); + frame[0..4].copy_from_slice(&total_len.to_le_bytes()); + } +} diff --git 
a/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/rntbd/status.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/rntbd/status.rs new file mode 100644 index 00000000000..044d43d4483 --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/rntbd/status.rs @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//! RNTBD status mapping helpers. + +use azure_core::http::StatusCode; + +use crate::models::CosmosStatus; + +/// Maps RNTBD frame status fields into a [`CosmosStatus`]. +/// +/// RNTBD carries the HTTP status in the frame header and the Cosmos DB +/// sub-status as an optional metadata token. +pub(crate) fn map_rntbd_status_to_cosmos_status( + http_status: u32, + sub_status: Option<u32>, +) -> CosmosStatus { + let status = StatusCode::from(http_status as u16); + let mut cosmos_status = CosmosStatus::new(status); + if let Some(sub_status) = sub_status { + cosmos_status = cosmos_status.with_sub_status(sub_status); + } + cosmos_status +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn maps_http_status_and_sub_status() { + let status = map_rntbd_status_to_cosmos_status(404, Some(1002)); + + assert_eq!(status.status_code(), StatusCode::NotFound); + assert_eq!(status.sub_status().unwrap().value(), 1002); + assert_eq!(status.name(), Some("ReadSessionNotAvailable")); + } + + #[test] + fn unknown_http_status_is_preserved() { + let status = map_rntbd_status_to_cosmos_status(449, None); + + assert_eq!(status.status_code(), StatusCode::UnknownValue(449)); + assert_eq!(status.sub_status(), None); + } +} diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/rntbd/tokens.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/rntbd/tokens.rs new file mode 100644 index 00000000000..125a717b042 --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/rntbd/tokens.rs @@ -0,0 +1,790 @@ +// Copyright (c) Microsoft Corporation. 
All rights reserved. +// Licensed under the MIT License. + +//! RNTBD metadata token codecs and wire ID mappings. + +use azure_core::error::ErrorKind; +use uuid::Uuid; + +use crate::models::{DefaultConsistencyLevel, OperationType, ResourceType}; + +/// The token type byte used by RNTBD metadata tokens. +/// +/// Variable-width types carry their own length prefix in the value payload. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub(crate) enum TokenType { + /// Single unsigned byte. + Byte, + /// Unsigned 16-bit integer encoded little-endian. + UShort, + /// Unsigned 32-bit integer encoded little-endian. + ULong, + /// Signed 32-bit integer encoded little-endian. + Long, + /// Unsigned 64-bit integer encoded little-endian. + ULongLong, + /// Signed 64-bit integer encoded little-endian. + LongLong, + /// UUID encoded in Microsoft GUID byte order. + Guid, + /// UTF-8 string prefixed with an unsigned byte length. + SmallString, + /// UTF-8 string prefixed with an unsigned 16-bit length. + String, + /// UTF-8 string prefixed with an unsigned 32-bit length. + ULongString, + /// Bytes prefixed with an unsigned byte length. + SmallBytes, + /// Bytes prefixed with an unsigned 16-bit length. + Bytes, + /// Bytes prefixed with an unsigned 32-bit length. + ULongBytes, + /// 32-bit floating point value encoded little-endian. + Float, + /// 64-bit floating point value encoded little-endian. + Double, + /// Invalid token type sentinel. 
+ Invalid, +} + +impl TryFrom<u8> for TokenType { + type Error = azure_core::Error; + + fn try_from(value: u8) -> azure_core::Result<Self> { + match value { + 0x00 => Ok(Self::Byte), + 0x01 => Ok(Self::UShort), + 0x02 => Ok(Self::ULong), + 0x03 => Ok(Self::Long), + 0x04 => Ok(Self::ULongLong), + 0x05 => Ok(Self::LongLong), + 0x06 => Ok(Self::Guid), + 0x07 => Ok(Self::SmallString), + 0x08 => Ok(Self::String), + 0x09 => Ok(Self::ULongString), + 0x0A => Ok(Self::SmallBytes), + 0x0B => Ok(Self::Bytes), + 0x0C => Ok(Self::ULongBytes), + 0x0D => Ok(Self::Float), + 0x0E => Ok(Self::Double), + 0xFF => Ok(Self::Invalid), + other => Err(data_conversion_error(format!( + "unknown RNTBD token type 0x{other:02X}" + ))), + } + } +} + +impl From<TokenType> for u8 { + fn from(value: TokenType) -> Self { + match value { + TokenType::Byte => 0x00, + TokenType::UShort => 0x01, + TokenType::ULong => 0x02, + TokenType::Long => 0x03, + TokenType::ULongLong => 0x04, + TokenType::LongLong => 0x05, + TokenType::Guid => 0x06, + TokenType::SmallString => 0x07, + TokenType::String => 0x08, + TokenType::ULongString => 0x09, + TokenType::SmallBytes => 0x0A, + TokenType::Bytes => 0x0B, + TokenType::ULongBytes => 0x0C, + TokenType::Float => 0x0D, + TokenType::Double => 0x0E, + TokenType::Invalid => 0xFF, + } + } +} + +/// A decoded RNTBD metadata token value. +/// +/// The enum variant determines the value codec used on the wire. +#[derive(Clone, Debug, PartialEq)] +pub(crate) enum TokenValue { + /// Single unsigned byte. + Byte(u8), + /// Unsigned 16-bit integer. + UShort(u16), + /// Unsigned 32-bit integer. + ULong(u32), + /// Signed 32-bit integer. + Long(i32), + /// Unsigned 64-bit integer. + ULongLong(u64), + /// Signed 64-bit integer. + LongLong(i64), + /// UUID in Microsoft GUID token byte order. + Guid(Uuid), + /// UTF-8 string with an unsigned byte length prefix. + SmallString(String), + /// UTF-8 string with an unsigned 16-bit length prefix. 
+ String(String), + /// UTF-8 string with an unsigned 32-bit length prefix. + ULongString(String), + /// Bytes with an unsigned byte length prefix. + SmallBytes(Vec<u8>), + /// Bytes with an unsigned 16-bit length prefix. + Bytes(Vec<u8>), + /// Bytes with an unsigned 32-bit length prefix. + ULongBytes(Vec<u8>), + /// 32-bit floating point value. + Float(f32), + /// 64-bit floating point value. + Double(f64), +} + +impl TokenValue { + fn token_type(&self) -> TokenType { + match self { + Self::Byte(_) => TokenType::Byte, + Self::UShort(_) => TokenType::UShort, + Self::ULong(_) => TokenType::ULong, + Self::Long(_) => TokenType::Long, + Self::ULongLong(_) => TokenType::ULongLong, + Self::LongLong(_) => TokenType::LongLong, + Self::Guid(_) => TokenType::Guid, + Self::SmallString(_) => TokenType::SmallString, + Self::String(_) => TokenType::String, + Self::ULongString(_) => TokenType::ULongString, + Self::SmallBytes(_) => TokenType::SmallBytes, + Self::Bytes(_) => TokenType::Bytes, + Self::ULongBytes(_) => TokenType::ULongBytes, + Self::Float(_) => TokenType::Float, + Self::Double(_) => TokenType::Double, + } + } + + fn encoded_len(&self) -> usize { + match self { + Self::Byte(_) => 1, + Self::UShort(_) => 2, + Self::ULong(_) | Self::Long(_) | Self::Float(_) => 4, + Self::ULongLong(_) | Self::LongLong(_) | Self::Double(_) => 8, + Self::Guid(_) => 16, + Self::SmallString(value) => 1 + value.len(), + Self::String(value) => 2 + value.len(), + Self::ULongString(value) => 4 + value.len(), + Self::SmallBytes(value) => 1 + value.len(), + Self::Bytes(value) => 2 + value.len(), + Self::ULongBytes(value) => 4 + value.len(), + } + } + + fn write_to(&self, out: &mut Vec<u8>) -> azure_core::Result<()> { + match self { + Self::Byte(value) => out.push(*value), + Self::UShort(value) => out.extend_from_slice(&value.to_le_bytes()), + Self::ULong(value) => out.extend_from_slice(&value.to_le_bytes()), + Self::Long(value) => out.extend_from_slice(&value.to_le_bytes()), + Self::ULongLong(value) => 
out.extend_from_slice(&value.to_le_bytes()), + Self::LongLong(value) => out.extend_from_slice(&value.to_le_bytes()), + Self::Guid(value) => write_guid_ms(out, *value), + Self::SmallString(value) => write_len_prefixed_u8(out, value.as_bytes())?, + Self::String(value) => write_len_prefixed_u16(out, value.as_bytes())?, + Self::ULongString(value) => write_len_prefixed_u32(out, value.as_bytes())?, + Self::SmallBytes(value) => write_len_prefixed_u8(out, value)?, + Self::Bytes(value) => write_len_prefixed_u16(out, value)?, + Self::ULongBytes(value) => write_len_prefixed_u32(out, value)?, + Self::Float(value) => out.extend_from_slice(&value.to_le_bytes()), + Self::Double(value) => out.extend_from_slice(&value.to_le_bytes()), + } + Ok(()) + } + + fn read_from(token_type: TokenType, src: &mut &[u8]) -> azure_core::Result { + match token_type { + TokenType::Byte => Ok(Self::Byte(read_u8(src)?)), + TokenType::UShort => Ok(Self::UShort(read_u16_le(src)?)), + TokenType::ULong => Ok(Self::ULong(read_u32_le(src)?)), + TokenType::Long => Ok(Self::Long(read_i32_le(src)?)), + TokenType::ULongLong => Ok(Self::ULongLong(read_u64_le(src)?)), + TokenType::LongLong => Ok(Self::LongLong(read_i64_le(src)?)), + TokenType::Guid => Ok(Self::Guid(read_guid_ms(src)?)), + TokenType::SmallString => { + let len = read_u8(src)? as usize; + Ok(Self::SmallString(read_utf8(src, len)?)) + } + TokenType::String => { + let len = read_u16_le(src)? as usize; + Ok(Self::String(read_utf8(src, len)?)) + } + TokenType::ULongString => { + let len = read_u32_le(src)? as usize; + Ok(Self::ULongString(read_utf8(src, len)?)) + } + TokenType::SmallBytes => { + let len = read_u8(src)? as usize; + Ok(Self::SmallBytes( + read_exact(src, len, "small bytes")?.to_vec(), + )) + } + TokenType::Bytes => { + let len = read_u16_le(src)? as usize; + Ok(Self::Bytes(read_exact(src, len, "bytes")?.to_vec())) + } + TokenType::ULongBytes => { + let len = read_u32_le(src)? 
as usize; + Ok(Self::ULongBytes( + read_exact(src, len, "ulong bytes")?.to_vec(), + )) + } + TokenType::Float => Ok(Self::Float(f32::from_le_bytes(read_array(src)?))), + TokenType::Double => Ok(Self::Double(f64::from_le_bytes(read_array(src)?))), + TokenType::Invalid => Err(data_conversion_error( + "invalid RNTBD token type sentinel encountered", + )), + } + } +} + +/// A single RNTBD metadata token. +/// +/// Tokens are encoded as a two-byte token ID, a one-byte [`TokenType`], and the +/// value bytes for that type. +#[derive(Clone, Debug, PartialEq)] +pub(crate) struct Token { + /// Token identifier from the RNTBD header table. + pub(crate) id: u16, + /// Decoded token value. + pub(crate) value: TokenValue, +} + +impl Token { + /// Creates a metadata token from an ID and typed value. + pub(crate) fn new(id: u16, value: TokenValue) -> Self { + Self { id, value } + } + + pub(crate) fn authorization_token(value: String) -> Self { + Self::new( + RntbdRequestToken::AuthorizationToken.into(), + TokenValue::String(value), + ) + } + + pub(crate) fn payload_present(value: bool) -> Self { + Self::new( + RntbdRequestToken::PayloadPresent.into(), + TokenValue::Byte(u8::from(value)), + ) + } + + pub(crate) fn date(value: String) -> Self { + Self::new( + RntbdRequestToken::Date.into(), + TokenValue::SmallString(value), + ) + } + + pub(crate) fn consistency_level(value: DefaultConsistencyLevel) -> Self { + let value = match value { + DefaultConsistencyLevel::Strong => 0x00, + DefaultConsistencyLevel::BoundedStaleness => 0x01, + DefaultConsistencyLevel::Session => 0x02, + DefaultConsistencyLevel::Eventual => 0x03, + DefaultConsistencyLevel::ConsistentPrefix => 0x04, + }; + Self::new( + RntbdRequestToken::ConsistencyLevel.into(), + TokenValue::Byte(value), + ) + } + + pub(crate) fn database_name(value: String) -> Self { + Self::new( + RntbdRequestToken::DatabaseName.into(), + TokenValue::String(value), + ) + } + + pub(crate) fn collection_name(value: String) -> Self { + 
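The consistency-level byte mapping above can be exercised in isolation. A minimal sketch, using a local `ConsistencyLevel` enum as a stand-in for the SDK's `DefaultConsistencyLevel` (the wire bytes 0x00–0x04 are the ones from `Token::consistency_level`):

```rust
// Stand-in for the SDK's DefaultConsistencyLevel (illustrative only).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ConsistencyLevel {
    Strong,
    BoundedStaleness,
    Session,
    Eventual,
    ConsistentPrefix,
}

// Maps a consistency level to its one-byte RNTBD wire value,
// mirroring the match in `Token::consistency_level`.
fn consistency_wire_byte(level: ConsistencyLevel) -> u8 {
    match level {
        ConsistencyLevel::Strong => 0x00,
        ConsistencyLevel::BoundedStaleness => 0x01,
        ConsistencyLevel::Session => 0x02,
        ConsistencyLevel::Eventual => 0x03,
        ConsistencyLevel::ConsistentPrefix => 0x04,
    }
}
```

The byte rides in a `TokenValue::Byte` token, so no length prefix is involved.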
Self::new( + RntbdRequestToken::CollectionName.into(), + TokenValue::String(value), + ) + } + + pub(crate) fn document_name(value: String) -> Self { + Self::new( + RntbdRequestToken::DocumentName.into(), + TokenValue::String(value), + ) + } + + pub(crate) fn transport_request_id(value: u32) -> Self { + Self::new( + RntbdRequestToken::TransportRequestId.into(), + TokenValue::ULong(value), + ) + } + + pub(crate) fn effective_partition_key(value: Vec) -> Self { + Self::new( + RntbdRequestToken::EffectivePartitionKey.into(), + TokenValue::Bytes(value), + ) + } + + pub(crate) fn sdk_supported_capabilities(value: u32) -> Self { + Self::new( + RntbdRequestToken::SDKSupportedCapabilities.into(), + TokenValue::ULong(value), + ) + } + + pub(crate) fn global_database_account_name(value: String) -> Self { + Self::new( + RntbdRequestToken::GlobalDatabaseAccountName.into(), + TokenValue::String(value), + ) + } + + /// Pagination cursor echoed back to the proxy on subsequent feed/query + /// requests. Wire format matches Java's `RntbdRequestHeader.ContinuationToken` + /// (ID 0x0006, string) — the SDK passes the value through unchanged so + /// the backend can resume from the previous offset. + pub(crate) fn continuation_token(value: String) -> Self { + Self::new( + RntbdRequestToken::ContinuationToken.into(), + TokenValue::String(value), + ) + } + + /// Returns the number of bytes this token occupies on the wire. + pub(super) fn encoded_len(&self) -> usize { + 2 + 1 + self.value.encoded_len() + } + + /// Writes this token to the output buffer. + /// + /// Returns an error if the token value exceeds the wire encoding's length + /// limits (e.g., a `SmallString` longer than 255 bytes). + pub(super) fn write_to(&self, out: &mut Vec) -> azure_core::Result<()> { + out.extend_from_slice(&self.id.to_le_bytes()); + out.push(self.value.token_type().into()); + self.value.write_to(out) + } + + /// Reads a token from the input slice and advances the slice. 
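The length-limit behavior documented on `write_to` (a `SmallString` longer than 255 bytes is rejected) can be illustrated with a standalone sketch of a one-byte length-prefixed writer. This mirrors the shape of the module's private helper but uses a plain `String` error type so it is self-contained:

```rust
// Writes `bytes` with a one-byte length prefix, as used by SmallString /
// SmallBytes tokens. Values longer than 255 bytes cannot be represented
// and must fail at encode time rather than be silently truncated.
fn write_len_prefixed_u8(out: &mut Vec<u8>, bytes: &[u8]) -> Result<(), String> {
    let len = u8::try_from(bytes.len()).map_err(|_| {
        format!(
            "value length {} exceeds u8 length-prefix maximum (255)",
            bytes.len()
        )
    })?;
    out.push(len); // one-byte length prefix
    out.extend_from_slice(bytes); // raw value bytes
    Ok(())
}
```

For example, `b"db1"` encodes as `[0x03, b'd', b'b', b'1']`, while a 256-byte value returns an error.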
+ pub(super) fn read_from(src: &mut &[u8]) -> azure_core::Result { + let id = read_u16_le(src)?; + let token_type = TokenType::try_from(read_u8(src)?)?; + let value = TokenValue::read_from(token_type, src)?; + Ok(Self { id, value }) + } +} + +/// RNTBD request metadata token IDs used by Gateway 2.0 dispatch. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub(crate) enum RntbdRequestToken { + AuthorizationToken, + PayloadPresent, + Date, + ContinuationToken, + ConsistencyLevel, + DatabaseName, + CollectionName, + DocumentName, + TransportRequestId, + EffectivePartitionKey, + SDKSupportedCapabilities, + GlobalDatabaseAccountName, +} + +impl TryFrom for RntbdRequestToken { + type Error = (); + + fn try_from(value: u16) -> Result { + match value { + 0x0001 => Ok(Self::AuthorizationToken), + 0x0002 => Ok(Self::PayloadPresent), + 0x0003 => Ok(Self::Date), + 0x0006 => Ok(Self::ContinuationToken), + 0x0010 => Ok(Self::ConsistencyLevel), + 0x0015 => Ok(Self::DatabaseName), + 0x0016 => Ok(Self::CollectionName), + 0x0017 => Ok(Self::DocumentName), + 0x004D => Ok(Self::TransportRequestId), + 0x005A => Ok(Self::EffectivePartitionKey), + 0x00A2 => Ok(Self::SDKSupportedCapabilities), + 0x00CE => Ok(Self::GlobalDatabaseAccountName), + _ => Err(()), + } + } +} + +impl From for u16 { + fn from(value: RntbdRequestToken) -> Self { + match value { + RntbdRequestToken::AuthorizationToken => 0x0001, + RntbdRequestToken::PayloadPresent => 0x0002, + RntbdRequestToken::Date => 0x0003, + RntbdRequestToken::ContinuationToken => 0x0006, + RntbdRequestToken::ConsistencyLevel => 0x0010, + RntbdRequestToken::DatabaseName => 0x0015, + RntbdRequestToken::CollectionName => 0x0016, + RntbdRequestToken::DocumentName => 0x0017, + RntbdRequestToken::TransportRequestId => 0x004D, + RntbdRequestToken::EffectivePartitionKey => 0x005A, + RntbdRequestToken::SDKSupportedCapabilities => 0x00A2, + RntbdRequestToken::GlobalDatabaseAccountName => 0x00CE, + } + } +} + +/// RNTBD response metadata token IDs 
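Putting the token ID table together with the `Token` layout (two-byte little-endian ID, one-byte type tag, then the value bytes), the simplest case is a `Byte`-typed token. `encode_byte_token` below is a hypothetical illustration, not part of the module:

```rust
// Encodes a Byte-typed RNTBD token: 2-byte little-endian token ID,
// 1-byte type tag (0x00 = TokenType::Byte), then the single value byte.
fn encode_byte_token(id: u16, value: u8) -> Vec<u8> {
    let mut out = Vec::with_capacity(4);
    out.extend_from_slice(&id.to_le_bytes()); // token ID, LE
    out.push(0x00); // TokenType::Byte wire tag
    out.push(value); // value bytes for this type
    out
}
```

So `Token::payload_present(true)` (ID 0x0002, `TokenValue::Byte(1)`) serializes as `[0x02, 0x00, 0x00, 0x01]`.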
recognized by Slice 1. +pub(super) enum RntbdResponseToken { + /// Continuation token. + ContinuationToken, + /// Entity tag. + ETag, + /// Retry-after delay in milliseconds. + RetryAfterMilliseconds, + /// Logical sequence number. + Lsn, + /// Request charge in request units. + RequestCharge, + /// Owner full name. + OwnerFullName, + /// Cosmos DB sub-status code. + SubStatus, + /// Partition key range identifier. + PartitionKeyRangeId, + /// Item logical sequence number. + ItemLsn, + /// Global committed logical sequence number. + GlobalCommittedLsn, + /// Transport request identifier. + TransportRequestId, + /// Session token. + SessionToken, +} + +impl TryFrom for RntbdResponseToken { + type Error = (); + + fn try_from(value: u16) -> Result { + match value { + 0x0003 => Ok(Self::ContinuationToken), + 0x0004 => Ok(Self::ETag), + 0x000C => Ok(Self::RetryAfterMilliseconds), + 0x0013 => Ok(Self::Lsn), + 0x0015 => Ok(Self::RequestCharge), + 0x0017 => Ok(Self::OwnerFullName), + 0x001C => Ok(Self::SubStatus), + 0x0021 => Ok(Self::PartitionKeyRangeId), + 0x0032 => Ok(Self::ItemLsn), + 0x0029 => Ok(Self::GlobalCommittedLsn), + 0x0035 => Ok(Self::TransportRequestId), + 0x003E => Ok(Self::SessionToken), + _ => Err(()), + } + } +} + +/// RNTBD resource type wire ID. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub(super) struct RntbdResourceType(u16); + +impl RntbdResourceType { + /// Returns the underlying RNTBD resource type ID. 
+ pub(super) fn value(self) -> u16 { + self.0 + } +} + +impl From for RntbdResourceType { + fn from(value: ResourceType) -> Self { + let id = match value { + ResourceType::DatabaseAccount => 0x0014, + ResourceType::Database => 0x0001, + ResourceType::DocumentCollection => 0x0002, + ResourceType::Document => 0x0003, + ResourceType::StoredProcedure => 0x0007, + ResourceType::Trigger => 0x0009, + ResourceType::UserDefinedFunction => 0x000A, + ResourceType::PartitionKeyRange => 0x0016, + ResourceType::Offer => 0x000F, + }; + Self(id) + } +} + +impl TryFrom for RntbdResourceType { + type Error = azure_core::Error; + + fn try_from(value: u16) -> azure_core::Result { + match value { + 0x0014 | 0x0001 | 0x0002 | 0x0003 | 0x0007 | 0x0009 | 0x000A | 0x0016 | 0x000F => { + Ok(Self(value)) + } + other => Err(data_conversion_error(format!( + "unknown RNTBD resource type 0x{other:04X}" + ))), + } + } +} + +impl TryFrom for ResourceType { + type Error = azure_core::Error; + + fn try_from(value: RntbdResourceType) -> azure_core::Result { + match value.0 { + 0x0014 => Ok(Self::DatabaseAccount), + 0x0001 => Ok(Self::Database), + 0x0002 => Ok(Self::DocumentCollection), + 0x0003 => Ok(Self::Document), + 0x0007 => Ok(Self::StoredProcedure), + 0x0009 => Ok(Self::Trigger), + 0x000A => Ok(Self::UserDefinedFunction), + 0x0016 => Ok(Self::PartitionKeyRange), + 0x000F => Ok(Self::Offer), + _ => Err(data_conversion_error("unknown RNTBD resource type")), + } + } +} + +/// RNTBD operation type wire ID. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub(super) struct RntbdOperationType(u16); + +impl RntbdOperationType { + /// Returns the underlying RNTBD operation type ID. 
+ pub(super) fn value(self) -> u16 { + self.0 + } +} + +impl From for RntbdOperationType { + fn from(value: OperationType) -> Self { + let id = match value { + OperationType::Create => 0x0001, + OperationType::Read => 0x0003, + OperationType::ReadFeed => 0x0004, + OperationType::Delete => 0x0005, + OperationType::Replace => 0x0006, + OperationType::Execute => 0x0008, + OperationType::SqlQuery => 0x0009, + // QueryPlan has no distinct Gateway 2.0 wire ID; Java encodes it as SqlQuery + // with additional metadata that lands in a later slice. + OperationType::QueryPlan => 0x0009, + OperationType::Query => 0x000F, + OperationType::Head => 0x0011, + OperationType::HeadFeed => 0x0012, + OperationType::Upsert => 0x0013, + OperationType::Batch => 0x0025, + }; + Self(id) + } +} + +impl TryFrom for RntbdOperationType { + type Error = azure_core::Error; + + fn try_from(value: u16) -> azure_core::Result { + match value { + 0x0001 | 0x0003 | 0x0004 | 0x0005 | 0x0006 | 0x0008 | 0x0009 | 0x000F | 0x0011 + | 0x0012 | 0x0013 | 0x0025 => Ok(Self(value)), + other => Err(data_conversion_error(format!( + "unknown RNTBD operation type 0x{other:04X}" + ))), + } + } +} + +impl TryFrom for OperationType { + type Error = azure_core::Error; + + fn try_from(value: RntbdOperationType) -> azure_core::Result { + match value.0 { + 0x0001 => Ok(Self::Create), + 0x0003 => Ok(Self::Read), + 0x0004 => Ok(Self::ReadFeed), + 0x0005 => Ok(Self::Delete), + 0x0006 => Ok(Self::Replace), + 0x0008 => Ok(Self::Execute), + 0x0009 => Ok(Self::SqlQuery), + 0x000F => Ok(Self::Query), + 0x0011 => Ok(Self::Head), + 0x0012 => Ok(Self::HeadFeed), + 0x0013 => Ok(Self::Upsert), + 0x0025 => Ok(Self::Batch), + _ => Err(data_conversion_error("unknown RNTBD operation type")), + } + } +} + +/// Creates a data-conversion error for malformed RNTBD input. 
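Because `QueryPlan` shares SqlQuery's wire ID (0x0009), the operation-type encode and decode are not inverses for that one case: encoding `QueryPlan` and then decoding yields `SqlQuery`. A small sketch with local stand-ins (`Op`, `op_to_wire`, `op_from_wire` are illustrative, not driver APIs):

```rust
// Local stand-in for a few OperationType variants (illustrative only).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Op {
    Read,
    SqlQuery,
    QueryPlan,
}

fn op_to_wire(op: Op) -> u16 {
    match op {
        Op::Read => 0x0003,
        // QueryPlan has no distinct wire ID; it encodes as SqlQuery.
        Op::SqlQuery | Op::QueryPlan => 0x0009,
    }
}

fn op_from_wire(id: u16) -> Option<Op> {
    match id {
        0x0003 => Some(Op::Read),
        0x0009 => Some(Op::SqlQuery), // QueryPlan is not recoverable here
        _ => None,
    }
}
```

This is why the `TryFrom` impl above only ever produces `SqlQuery` for 0x0009; the query-plan distinction travels in separate metadata.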
+pub(super) fn data_conversion_error(message: impl Into<String>) -> azure_core::Error { + azure_core::Error::with_message(ErrorKind::DataConversion, message.into()) +} + +/// Writes a UUID using the Gateway 2.0 activity ID byte order. +/// +/// The wire form is the UUID most-significant 64 bits in little-endian order +/// followed by the least-significant 64 bits in little-endian order. +pub(super) fn write_uuid_le(out: &mut Vec<u8>, id: Uuid) { + let value = id.as_u128(); + let msb = (value >> 64) as u64; + let lsb = value as u64; + out.extend_from_slice(&msb.to_le_bytes()); + out.extend_from_slice(&lsb.to_le_bytes()); +} + +/// Reads a UUID using the Gateway 2.0 activity ID byte order. +pub(super) fn read_uuid_le(src: &mut &[u8]) -> azure_core::Result<Uuid> { + let msb = read_u64_le(src)?; + let lsb = read_u64_le(src)?; + Ok(Uuid::from_u128(((msb as u128) << 64) | lsb as u128)) +} + +/// Reads an unsigned byte from the input slice. +pub(super) fn read_u8(src: &mut &[u8]) -> azure_core::Result<u8> { + Ok(read_exact(src, 1, "u8")?[0]) +} + +/// Reads an unsigned 16-bit little-endian integer from the input slice. +pub(super) fn read_u16_le(src: &mut &[u8]) -> azure_core::Result<u16> { + Ok(u16::from_le_bytes(read_array(src)?)) +} + +/// Reads an unsigned 32-bit little-endian integer from the input slice. +pub(super) fn read_u32_le(src: &mut &[u8]) -> azure_core::Result<u32> { + Ok(u32::from_le_bytes(read_array(src)?)) +} + +/// Reads an unsigned 64-bit little-endian integer from the input slice. 
+pub(super) fn read_u64_le(src: &mut &[u8]) -> azure_core::Result<u64> { + Ok(u64::from_le_bytes(read_array(src)?)) +} + +fn read_i32_le(src: &mut &[u8]) -> azure_core::Result<i32> { + Ok(i32::from_le_bytes(read_array(src)?)) +} + +fn read_i64_le(src: &mut &[u8]) -> azure_core::Result<i64> { + Ok(i64::from_le_bytes(read_array(src)?)) +} + +fn read_array<const N: usize>(src: &mut &[u8]) -> azure_core::Result<[u8; N]> { + let bytes = read_exact(src, N, "fixed-width value")?; + let mut out = [0_u8; N]; + out.copy_from_slice(bytes); + Ok(out) +} + +fn read_exact<'a>(src: &mut &'a [u8], len: usize, context: &str) -> azure_core::Result<&'a [u8]> { + if src.len() < len { + return Err(data_conversion_error(format!( + "RNTBD {context} needs {len} bytes but only {} remain", + src.len() + ))); + } + let (head, tail) = src.split_at(len); + *src = tail; + Ok(head) +} + +fn read_utf8(src: &mut &[u8], len: usize) -> azure_core::Result<String> { + let bytes = read_exact(src, len, "UTF-8 string")?; + String::from_utf8(bytes.to_vec()) + .map_err(|e| azure_core::Error::new(ErrorKind::DataConversion, e)) +} + +fn write_len_prefixed_u8(out: &mut Vec<u8>, bytes: &[u8]) -> azure_core::Result<()> { + let len = u8::try_from(bytes.len()).map_err(|_| { + data_conversion_error(format!( + "RNTBD value length {} exceeds u8 length-prefix maximum (255)", + bytes.len() + )) + })?; + out.push(len); + out.extend_from_slice(bytes); + Ok(()) +} + +fn write_len_prefixed_u16(out: &mut Vec<u8>, bytes: &[u8]) -> azure_core::Result<()> { + let len = u16::try_from(bytes.len()).map_err(|_| { + data_conversion_error(format!( + "RNTBD value length {} exceeds u16 length-prefix maximum (65535)", + bytes.len() + )) + })?; + out.extend_from_slice(&len.to_le_bytes()); + out.extend_from_slice(bytes); + Ok(()) +} + +fn write_len_prefixed_u32(out: &mut Vec<u8>, bytes: &[u8]) -> azure_core::Result<()> { + let len = u32::try_from(bytes.len()).map_err(|_| { + data_conversion_error(format!( + "RNTBD value length {} exceeds u32 length-prefix maximum (4294967295)", + 
bytes.len() + )) + })?; + out.extend_from_slice(&len.to_le_bytes()); + out.extend_from_slice(bytes); + Ok(()) +} + +fn write_guid_ms(out: &mut Vec, id: Uuid) { + let (data1, data2, data3, data4) = id.as_fields(); + out.extend_from_slice(&data1.to_le_bytes()); + out.extend_from_slice(&data2.to_le_bytes()); + out.extend_from_slice(&data3.to_le_bytes()); + out.extend_from_slice(data4); +} + +fn read_guid_ms(src: &mut &[u8]) -> azure_core::Result { + let data1 = u32::from_le_bytes(read_array(src)?); + let data2 = u16::from_le_bytes(read_array(src)?); + let data3 = u16::from_le_bytes(read_array(src)?); + let data4 = read_array(src)?; + Ok(Uuid::from_fields(data1, data2, data3, &data4)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn activity_id_uses_msb_lsb_little_endian_order() { + let id = Uuid::parse_str("0a1b2c3d-4e5f-6789-abcd-ef0123456789").unwrap(); + let mut bytes = Vec::new(); + + write_uuid_le(&mut bytes, id); + + assert_eq!( + bytes, + vec![ + 0x89, 0x67, 0x5f, 0x4e, 0x3d, 0x2c, 0x1b, 0x0a, 0x89, 0x67, 0x45, 0x23, 0x01, 0xef, + 0xcd, 0xab, + ] + ); + + let mut src = bytes.as_slice(); + let decoded = read_uuid_le(&mut src).unwrap(); + assert_eq!(decoded, id); + assert!(src.is_empty()); + } + + #[test] + fn invalid_token_type_sentinel_is_rejected() { + let mut src = [0x01, 0x00, 0xFF].as_slice(); + + let err = Token::read_from(&mut src).unwrap_err(); + + assert_eq!(*err.kind(), ErrorKind::DataConversion); + } + + #[test] + fn small_string_rejects_length_past_remaining_buffer() { + let mut src = [0x01, 0x00, 0x07, 0x05, b'h', b'i'].as_slice(); + + let err = Token::read_from(&mut src).unwrap_err(); + + assert_eq!(*err.kind(), ErrorKind::DataConversion); + } +} diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/transport_pipeline.rs b/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/transport_pipeline.rs index ff0cdde4470..fe69e21ba8d 100644 --- 
a/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/transport_pipeline.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/driver/transport/transport_pipeline.rs @@ -31,11 +31,13 @@ use crate::{ use super::{ adaptive_transport::AdaptiveTransport, cosmos_headers::apply_cosmos_headers, cosmos_transport_client::HttpRequest, infer_request_sent_status, request_signing::sign_request, - sharded_transport::EndpointKey, + sharded_transport::EndpointKey, unwrap_response_for_gateway20, wrap_request_for_gateway20, + WrapInputs, }; use crate::driver::pipeline::components::{ - ThrottleAction, ThrottleRetryState, TransportOutcome, TransportRequest, TransportResult, + ThrottleAction, ThrottleRetryState, TransportMode, TransportOutcome, TransportRequest, + TransportResult, }; /// Cosmos DB retry-after header (milliseconds). @@ -153,6 +155,8 @@ pub(crate) struct TransportPipelineContext<'a> { /// Computed once by the operation pipeline from the routing-level endpoint /// so the transport pipeline doesn't need to allocate a `String` per attempt. pub endpoint_key: EndpointKey, + /// Global database account name used by Gateway 2.0 request wrapping. + pub account_name: Option, } /// Executes a single transport attempt. 
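The transport-pipeline changes in this file gate Gateway 2.0 handling on a per-request transport mode: requests are wrapped only in `Gateway20` mode, and only a 200 outer envelope is unwrapped (any other outer status is proxy-generated and propagates unchanged). A minimal sketch of that gating, with `TransportMode` as a local stand-in and the status check simplified to a bare `u16`:

```rust
// Local stand-in for the driver's TransportMode (illustrative only).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TransportMode {
    Gateway,
    Gateway20,
}

// Requests are wrapped into the RNTBD-over-HTTP envelope only for
// Gateway 2.0 dispatch; standard gateway requests go out untouched.
fn should_wrap(mode: TransportMode) -> bool {
    mode == TransportMode::Gateway20
}

// Only a 200 outer response carries an RNTBD-framed inner response to
// decode; a 502 from the proxy, for example, surfaces as-is.
fn should_unwrap(mode: TransportMode, outer_status: u16) -> bool {
    mode == TransportMode::Gateway20 && outer_status == 200
}
```

The inner response can still carry a non-success status (e.g., a 401), which is why the unwrap step, not the outer HTTP status, determines the final outcome in `Gateway20` mode.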
@@ -251,6 +255,25 @@ pub(crate) async fn execute_transport_pipeline( }; } + let should_unwrap_gateway20 = request.transport_mode == TransportMode::Gateway20; + if should_unwrap_gateway20 { + let wrap_inputs = WrapInputs { + auth_context: &request.auth_context, + operation_type: request.operation_type, + resource_type: request.auth_context.resource_type, + partition_key: request.partition_key.as_ref(), + partition_key_definition: request.partition_key_definition.as_ref(), + effective_consistency: request.effective_consistency, + account_name: ctx.account_name.as_deref(), + }; + match wrap_request_for_gateway20(&http_request, &wrap_inputs) { + Ok(wrapped_request) => http_request = wrapped_request, + Err(e) => { + return gateway20_wrap_error_result(e, request_handle, diagnostics); + } + } + } + // Record transport start event diagnostics.add_event( request_handle, @@ -274,6 +297,7 @@ pub(crate) async fn execute_transport_pipeline( diagnostics, excluded_shard_id.take(), endpoint_key, + should_unwrap_gateway20, ) .await; @@ -369,6 +393,7 @@ fn deadline_exceeded_result(request_sent: RequestSentStatus) -> TransportResult TransportResult::deadline_exceeded(request_sent) } +#[allow(clippy::too_many_arguments)] async fn execute_http_attempt( http_request: &HttpRequest, transport: &AdaptiveTransport, @@ -377,6 +402,7 @@ async fn execute_http_attempt( diagnostics: &mut DiagnosticsContextBuilder, excluded_shard_id: Option, endpoint_key: &EndpointKey, + should_unwrap_gateway20: bool, ) -> ExecutedTransportAttempt { if let Some(timeout_duration) = per_request_timeout { // Pre-select the shard so we know which shard the request was dispatched @@ -405,9 +431,12 @@ async fn execute_http_attempt( pin_mut!(timeout_future); return match futures::future::select(transport_future, timeout_future).await { - Either::Left((attempt_result, _)) => { - finalize_http_attempt(attempt_result, request_handle, diagnostics) - } + Either::Left((attempt_result, _)) => finalize_http_attempt( + 
attempt_result, + request_handle, + diagnostics, + should_unwrap_gateway20, + ), Either::Right((_, _remaining_transport_future)) => { diagnostics.add_event( request_handle, @@ -432,7 +461,12 @@ async fn execute_http_attempt( None, ) .await; - finalize_http_attempt(attempt_result, request_handle, diagnostics) + finalize_http_attempt( + attempt_result, + request_handle, + diagnostics, + should_unwrap_gateway20, + ) } async fn execute_http_attempt_future( @@ -472,6 +506,7 @@ fn finalize_http_attempt( attempt_result: HttpAttemptResult, request_handle: RequestHandle, diagnostics: &mut DiagnosticsContextBuilder, + should_unwrap_gateway20: bool, ) -> ExecutedTransportAttempt { match attempt_result { HttpAttemptResult::Response { @@ -488,6 +523,36 @@ fn finalize_http_attempt( if let Some(shard_diagnostics) = shard_diagnostics.clone() { diagnostics.set_transport_shard(request_handle, shard_diagnostics); } + + let (status_code, headers, body) = if should_unwrap_gateway20 + && status_code == azure_core::http::StatusCode::Ok + { + match unwrap_response_for_gateway20(super::cosmos_transport_client::HttpResponse { + status: u16::from(status_code), + headers, + body, + }) { + Ok(response) => ( + azure_core::http::StatusCode::from(response.status), + response.headers, + response.body, + ), + Err(error) => { + return ExecutedTransportAttempt { + result: gateway20_unwrap_error_result( + error, + request_handle, + diagnostics, + ), + shard_id, + shard_diagnostics, + }; + } + } + } else { + (status_code, headers, body) + }; + ExecutedTransportAttempt { result: map_http_response_payload( status_code, @@ -548,6 +613,56 @@ fn format_transport_error_details(error: &azure_core::Error) -> String { crate::driver::error_chain_summary(error) } +fn gateway20_wrap_error_result( + error: azure_core::Error, + request_handle: RequestHandle, + diagnostics: &mut DiagnosticsContextBuilder, +) -> TransportResult { + let status = CosmosStatus::CLIENT_GENERATED_400; + let error_details = 
format_transport_error_details(&error); + diagnostics.fail_transport_request( + request_handle, + error_details, + RequestSentStatus::NotSent, + status, + ); + + TransportResult { + outcome: TransportOutcome::TransportError { + status, + error, + request_sent: RequestSentStatus::NotSent, + }, + } +} + +fn gateway20_unwrap_error_result( + error: azure_core::Error, + request_handle: RequestHandle, + diagnostics: &mut DiagnosticsContextBuilder, +) -> TransportResult { + let status = CosmosStatus::TRANSPORT_GENERATED_503; + let error_details = format_transport_error_details(&error); + diagnostics.add_event( + request_handle, + RequestEvent::new(RequestEventType::TransportFailed).with_details(error_details.clone()), + ); + diagnostics.fail_transport_request( + request_handle, + error_details, + RequestSentStatus::Sent, + status, + ); + + TransportResult { + outcome: TransportOutcome::TransportError { + status, + error, + request_sent: RequestSentStatus::Sent, + }, + } +} + fn transport_error_result( error: azure_core::Error, headers_received: bool, @@ -660,6 +775,7 @@ fn map_http_response_payload( mod tests { use super::*; use std::{ + collections::VecDeque, sync::{Arc, Mutex}, time::Duration, }; @@ -667,7 +783,7 @@ mod tests { use async_trait::async_trait; use crate::{ - diagnostics::DiagnosticsContextBuilder, + diagnostics::{DiagnosticsContextBuilder, RequestSentStatus}, driver::{ routing::CosmosEndpoint, transport::{ @@ -678,7 +794,7 @@ mod tests { http_client_factory::{HttpClientConfig, HttpClientFactory}, }, }, - models::{ActivityId, Credential, ResourceType}, + models::{ActivityId, Credential, DefaultConsistencyLevel, OperationType, ResourceType}, options::DiagnosticsOptions, }; @@ -889,6 +1005,11 @@ mod tests { let request = TransportRequest { method: azure_core::http::Method::Get, endpoint: endpoint.clone(), + transport_mode: TransportMode::Gateway, + operation_type: OperationType::Read, + partition_key: None, + partition_key_definition: None, + 
effective_consistency: DefaultConsistencyLevel::Session, url: endpoint.url().clone(), headers: azure_core::http::headers::Headers::new(), body: None, @@ -1063,6 +1190,7 @@ mod tests { pipeline_type: PipelineType::DataPlane, transport_security: TransportSecurity::Secure, endpoint_key: test_endpoint_key(), + account_name: None, }, &mut diagnostics, ) @@ -1111,6 +1239,7 @@ mod tests { pipeline_type: PipelineType::DataPlane, transport_security: TransportSecurity::Secure, endpoint_key: test_endpoint_key(), + account_name: None, }, &mut diagnostics, ) @@ -1148,6 +1277,7 @@ mod tests { pipeline_type: PipelineType::DataPlane, transport_security: TransportSecurity::Secure, endpoint_key: test_endpoint_key(), + account_name: None, }, &mut diagnostics, ) @@ -1183,6 +1313,7 @@ mod tests { pipeline_type: PipelineType::DataPlane, transport_security: TransportSecurity::Secure, endpoint_key: test_endpoint_key(), + account_name: None, }, &mut diagnostics, ) @@ -1207,6 +1338,382 @@ mod tests { assert_eq!(requests[0].request_sent(), RequestSentStatus::NotSent); } + #[derive(Debug)] + struct Gateway20MockTransportClient { + responses: Mutex<VecDeque<HttpResponse>>, + requests: Mutex<Vec<HttpRequest>>, + } + + impl Gateway20MockTransportClient { + fn new(responses: Vec<HttpResponse>) -> Self { + Self { + responses: Mutex::new(responses.into()), + requests: Mutex::new(Vec::new()), + } + } + + fn requests(&self) -> Vec<HttpRequest> { + 
self.requests.lock().unwrap().clone() + } + } + + #[async_trait] + impl TransportClient for Gateway20MockTransportClient { + async fn send(&self, request: &HttpRequest) -> Result { + self.requests.lock().unwrap().push(request.clone()); + self.responses.lock().unwrap().pop_front().ok_or_else(|| { + TransportError::new( + azure_core::Error::with_message(ErrorKind::Other, "no response queued"), + RequestSentStatus::Unknown, + ) + }) + } + } + + const GATEWAY20_ACTIVITY_ID: &str = "00112233-4455-6677-8899-aabbccddeeff"; + + fn gateway20_transport_request(transport_mode: TransportMode) -> TransportRequest { + let endpoint = CosmosEndpoint::global( + url::Url::parse("https://test-thin.documents.azure.com:444/").unwrap(), + ); + let mut headers = azure_core::http::headers::Headers::new(); + headers.insert("x-ms-activity-id", GATEWAY20_ACTIVITY_ID); + TransportRequest { + method: azure_core::http::Method::Get, + endpoint: endpoint.clone(), + transport_mode, + operation_type: OperationType::Read, + partition_key: None, + partition_key_definition: None, + effective_consistency: DefaultConsistencyLevel::Session, + url: endpoint.url().clone(), + headers, + body: None, + auth_context: super::super::AuthorizationContext::new( + azure_core::http::Method::Get, + ResourceType::Document, + "dbs/db1/colls/coll1/docs/doc1", + ), + execution_context: ExecutionContext::Initial, + deadline: None, + } + } + + fn gateway20_context<'a>( + client: &'a AdaptiveTransport, + endpoint_key: EndpointKey, + account_name: Option, + credential: &'a Credential, + user_agent: &'a azure_core::http::headers::HeaderValue, + ) -> TransportPipelineContext<'a> { + TransportPipelineContext { + transport: client, + allow_sent_transport_retry: false, + credential, + user_agent, + pipeline_type: PipelineType::DataPlane, + transport_security: TransportSecurity::Secure, + endpoint_key, + account_name, + } + } + + fn gateway20_diagnostics() -> DiagnosticsContextBuilder { + DiagnosticsContextBuilder::new( + 
ActivityId::from_string(GATEWAY20_ACTIVITY_ID.to_owned()), + Arc::new(DiagnosticsOptions::default()), + ) + } + + #[tokio::test] + async fn gateway20_pipeline_wraps_request_and_unwraps_success_response() { + let mock = Arc::new(Gateway20MockTransportClient::new(vec![gateway20_response( + 200, + |_| {}, + b"{}", + )])); + let client = AdaptiveTransport::Gateway(mock.clone()); + let request = gateway20_transport_request(TransportMode::Gateway20); + let endpoint_key = request.endpoint.endpoint_key(); + let credential = Credential::from(azure_core::credentials::Secret::new("dGVzdA==")); + let user_agent = azure_core::http::headers::HeaderValue::from_static("test-agent"); + let mut diagnostics = gateway20_diagnostics(); + + let result = execute_transport_pipeline( + request, + &gateway20_context( + &client, + endpoint_key, + Some("account".to_owned()), + &credential, + &user_agent, + ), + &mut diagnostics, + ) + .await; + + match result.outcome { + TransportOutcome::Success { status, body, .. 
} => { + assert_eq!(status.status_code(), azure_core::http::StatusCode::Ok); + assert_eq!(body, b"{}".to_vec()); + } + other => panic!("expected success, got {other:?}"), + } + let captured = mock.requests(); + assert_eq!(captured.len(), 1); + assert_eq!(captured[0].method, azure_core::http::Method::Post); + assert_eq!( + captured[0] + .headers + .get_optional_str(&azure_core::http::headers::AUTHORIZATION), + None + ); + assert_eq!( + captured[0] + .headers + .get_optional_str(&azure_core::http::headers::USER_AGENT), + Some("test-agent") + ); + assert!(captured[0] + .body + .as_ref() + .is_some_and(|body| !body.is_empty())); + } + + #[tokio::test] + async fn gateway20_pipeline_leaves_standard_gateway_request_unwrapped() { + let mock = Arc::new(Gateway20MockTransportClient::new(vec![HttpResponse { + status: 200, + headers: azure_core::http::headers::Headers::new(), + body: b"plain".to_vec(), + }])); + let client = AdaptiveTransport::Gateway(mock.clone()); + let request = gateway20_transport_request(TransportMode::Gateway); + let endpoint_key = request.endpoint.endpoint_key(); + let credential = Credential::from(azure_core::credentials::Secret::new("dGVzdA==")); + let user_agent = azure_core::http::headers::HeaderValue::from_static("test-agent"); + let mut diagnostics = gateway20_diagnostics(); + + let result = execute_transport_pipeline( + request, + &gateway20_context(&client, endpoint_key, None, &credential, &user_agent), + &mut diagnostics, + ) + .await; + + match result.outcome { + TransportOutcome::Success { body, .. 
} => assert_eq!(body, b"plain".to_vec()), + other => panic!("expected success, got {other:?}"), + } + let captured = mock.requests(); + assert_eq!(captured.len(), 1); + assert_eq!(captured[0].method, azure_core::http::Method::Get); + assert!(captured[0] + .headers + .get_optional_str(&azure_core::http::headers::AUTHORIZATION) + .is_some()); + } + + #[tokio::test] + async fn gateway20_pipeline_decode_failure_is_sent_transport_error() { + let mock = Arc::new(Gateway20MockTransportClient::new(vec![HttpResponse { + status: 200, + headers: azure_core::http::headers::Headers::new(), + body: vec![1, 2, 3], + }])); + let client = AdaptiveTransport::Gateway(mock); + let request = gateway20_transport_request(TransportMode::Gateway20); + let endpoint_key = request.endpoint.endpoint_key(); + let credential = Credential::from(azure_core::credentials::Secret::new("dGVzdA==")); + let user_agent = azure_core::http::headers::HeaderValue::from_static("test-agent"); + let mut diagnostics = gateway20_diagnostics(); + + let result = execute_transport_pipeline( + request, + &gateway20_context( + &client, + endpoint_key, + Some("account".to_owned()), + &credential, + &user_agent, + ), + &mut diagnostics, + ) + .await; + + match result.outcome { + TransportOutcome::TransportError { + status, + request_sent, + .. 
+ } => { + assert_eq!(status, CosmosStatus::TRANSPORT_GENERATED_503); + assert_eq!(request_sent, RequestSentStatus::Sent); + } + other => panic!("expected transport error, got {other:?}"), + } + } + + #[tokio::test] + async fn gateway20_pipeline_outer_502_propagates_unchanged_without_unwrap() { + let mock = Arc::new(Gateway20MockTransportClient::new(vec![HttpResponse { + status: 502, + headers: azure_core::http::headers::Headers::new(), + body: vec![], + }])); + let client = AdaptiveTransport::Gateway(mock.clone()); + let request = gateway20_transport_request(TransportMode::Gateway20); + let endpoint_key = request.endpoint.endpoint_key(); + let credential = Credential::from(azure_core::credentials::Secret::new("dGVzdA==")); + let user_agent = azure_core::http::headers::HeaderValue::from_static("test-agent"); + let mut diagnostics = gateway20_diagnostics(); + + let result = execute_transport_pipeline( + request, + &gateway20_context( + &client, + endpoint_key, + Some("account".to_owned()), + &credential, + &user_agent, + ), + &mut diagnostics, + ) + .await; + + match result.outcome { + TransportOutcome::HttpError { + status, + body, + request_sent, + .. 
+ } => { + assert_eq!(u16::from(status.status_code()), 502); + assert_eq!(status.sub_status(), None); + assert_eq!(body, Vec::<u8>::new()); + assert_eq!(request_sent, RequestSentStatus::Sent); + } + other => panic!("expected HTTP error, got {other:?}"), + } + assert_eq!(mock.requests().len(), 1); + } + + #[tokio::test] + async fn gateway20_pipeline_inner_401_surfaces_as_inner_status() { + let mock = Arc::new(Gateway20MockTransportClient::new(vec![gateway20_response( + 401, + |_| {}, + b"", + )])); + let client = AdaptiveTransport::Gateway(mock.clone()); + let request = gateway20_transport_request(TransportMode::Gateway20); + let endpoint_key = request.endpoint.endpoint_key(); + let credential = Credential::from(azure_core::credentials::Secret::new("dGVzdA==")); + let user_agent = azure_core::http::headers::HeaderValue::from_static("test-agent"); + let mut diagnostics = gateway20_diagnostics(); + + let result = execute_transport_pipeline( + request, + &gateway20_context( + &client, + endpoint_key, + Some("account".to_owned()), + &credential, + &user_agent, + ), + &mut diagnostics, + ) + .await; + + match result.outcome { + TransportOutcome::HttpError { + status, + body, + request_sent, + ..
+ } => { + assert_eq!( + status.status_code(), + azure_core::http::StatusCode::Unauthorized + ); + assert_eq!(status.sub_status(), None); + assert_eq!(body, Vec::<u8>::new()); + assert_eq!(request_sent, RequestSentStatus::Sent); + } + other => panic!("expected HTTP error, got {other:?}"), + } + assert_eq!(mock.requests().len(), 1); + } + + #[tokio::test] + async fn gateway20_pipeline_uses_inner_retry_after_for_throttle_retry() { + let mock = Arc::new(Gateway20MockTransportClient::new(vec![ + gateway20_response( + 429, + |bytes| write_gateway20_u32_token(bytes, 0x000C, 0), + b"", + ), + gateway20_response(200, |_| {}, b"{}"), + ])); + let client = AdaptiveTransport::Gateway(mock.clone()); + let request = gateway20_transport_request(TransportMode::Gateway20); + let endpoint_key = request.endpoint.endpoint_key(); + let credential = Credential::from(azure_core::credentials::Secret::new("dGVzdA==")); + let user_agent = azure_core::http::headers::HeaderValue::from_static("test-agent"); + let mut diagnostics = gateway20_diagnostics(); + + let result = execute_transport_pipeline( + request, + &gateway20_context( + &client, + endpoint_key, + Some("account".to_owned()), + &credential, + &user_agent, + ), + &mut diagnostics, + ) + .await; + + assert!(matches!(result.outcome, TransportOutcome::Success { ..
})); + assert_eq!(mock.requests().len(), 2); + } + + fn gateway20_response( + status: u32, + write_tokens: impl FnOnce(&mut Vec<u8>), + body: &[u8], + ) -> HttpResponse { + let mut bytes = Vec::new(); + bytes.extend_from_slice(&0_u32.to_le_bytes()); + bytes.extend_from_slice(&status.to_le_bytes()); + write_gateway20_uuid( + &mut bytes, + uuid::Uuid::parse_str(GATEWAY20_ACTIVITY_ID).unwrap(), + ); + write_tokens(&mut bytes); + bytes.extend_from_slice(body); + let total_len = u32::try_from(bytes.len()).unwrap(); + bytes[0..4].copy_from_slice(&total_len.to_le_bytes()); + HttpResponse { + status: 200, + headers: azure_core::http::headers::Headers::new(), + body: bytes, + } + } + + fn write_gateway20_u32_token(bytes: &mut Vec<u8>, id: u16, value: u32) { + bytes.extend_from_slice(&id.to_le_bytes()); + bytes.push(0x02); + bytes.extend_from_slice(&value.to_le_bytes()); + } + + fn write_gateway20_uuid(bytes: &mut Vec<u8>, value: uuid::Uuid) { + let value = value.as_u128(); + bytes.extend_from_slice(&((value >> 64) as u64).to_le_bytes()); + bytes.extend_from_slice(&(value as u64).to_le_bytes()); + } + #[test] fn format_transport_error_details_includes_error_chain() { let inner = std::io::Error::new(std::io::ErrorKind::ConnectionReset, "socket reset"); diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/fault_injection/condition.rs b/sdk/cosmos/azure_data_cosmos_driver/src/fault_injection/condition.rs index 114383144c0..1ecfc3b3eaa 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/fault_injection/condition.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/fault_injection/condition.rs @@ -4,6 +4,7 @@ //! Defines conditions for when fault injection rules should be applied. use super::FaultOperationType; +use crate::diagnostics::TransportKind; use crate::options::Region; /// Defines the condition under which a fault injection rule should be applied.
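The `gateway20_response` / `write_gateway20_*` test helpers above assemble a framed Gateway 2.0 response: a little-endian `u32` total-length prefix (backfilled once the full frame is known), a `u32` status, the activity UUID as two little-endian `u64` halves, any tokens, then the body. A minimal standalone sketch of that framing, with tokens omitted for brevity; the function name and token-free layout here are illustrative, not the driver's API:

```rust
// Sketch of the Gateway 2.0 response frame the test helpers above assume:
// [total_len: u32 LE][status: u32 LE][activity_id: two u64 LE halves][body]
// (tokens between the UUID and the body are omitted in this sketch).
fn frame_response(status: u32, activity_id: u128, body: &[u8]) -> Vec<u8> {
    let mut bytes = Vec::new();
    bytes.extend_from_slice(&0_u32.to_le_bytes()); // length placeholder
    bytes.extend_from_slice(&status.to_le_bytes());
    bytes.extend_from_slice(&((activity_id >> 64) as u64).to_le_bytes());
    bytes.extend_from_slice(&(activity_id as u64).to_le_bytes());
    bytes.extend_from_slice(body);
    let total = u32::try_from(bytes.len()).unwrap();
    bytes[0..4].copy_from_slice(&total.to_le_bytes()); // backfill the prefix
    bytes
}

fn main() {
    let frame = frame_response(429, 0, b"{}");
    assert_eq!(frame.len(), 26); // 4 (len) + 4 (status) + 16 (uuid) + 2 (body)
    assert_eq!(u32::from_le_bytes(frame[0..4].try_into().unwrap()), 26);
    assert_eq!(u32::from_le_bytes(frame[4..8].try_into().unwrap()), 429);
}
```

Backfilling the length prefix after serialization mirrors how the helper computes `total_len` only after the tokens and body have been written.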
@@ -13,6 +14,7 @@ pub struct FaultInjectionCondition { operation_type: Option<FaultOperationType>, region: Option<Region>, container_id: Option<String>, + transport_kind: Option<TransportKind>, } impl FaultInjectionCondition { @@ -30,6 +32,15 @@ impl FaultInjectionCondition { pub fn container_id(&self) -> Option<&str> { self.container_id.as_deref() } + + /// Returns the transport kind to which the fault injection applies. + /// + /// When `Some`, the rule only matches requests that travelled through the + /// specified transport (e.g. `TransportKind::Gateway20`). When `None`, the + /// rule matches every transport (including metadata, gateway, and Gateway 2.0). + pub fn transport_kind(&self) -> Option<TransportKind> { + self.transport_kind + } } /// Builder for creating a FaultInjectionCondition. @@ -38,6 +49,7 @@ pub struct FaultInjectionConditionBuilder { operation_type: Option<FaultOperationType>, region: Option<Region>, container_id: Option<String>, + transport_kind: Option<TransportKind>, } impl FaultInjectionConditionBuilder { @@ -47,6 +59,7 @@ impl FaultInjectionConditionBuilder { operation_type: None, region: None, container_id: None, + transport_kind: None, } } @@ -68,12 +81,24 @@ impl FaultInjectionConditionBuilder { self } + /// Restricts the rule to a specific transport kind. + /// + /// Use this to scope a fault to (for example) only Gateway 2.0 traffic + /// (`TransportKind::Gateway20`) while leaving the standard gateway path + /// untouched. When unset, the rule applies regardless of which transport + /// carried the request. + pub fn with_transport_kind(mut self, transport_kind: TransportKind) -> Self { + self.transport_kind = Some(transport_kind); + self + } + + /// Builds the FaultInjectionCondition.
pub fn build(self) -> FaultInjectionCondition { FaultInjectionCondition { operation_type: self.operation_type, region: self.region, container_id: self.container_id, + transport_kind: self.transport_kind, } } } @@ -81,6 +106,7 @@ impl FaultInjectionConditionBuilder { #[cfg(test)] mod tests { use super::FaultInjectionConditionBuilder; + use crate::diagnostics::TransportKind; #[test] fn builder_default() { @@ -88,6 +114,14 @@ mod tests { let condition = builder.build(); assert!(condition.operation_type().is_none()); assert!(condition.region().is_none()); - assert!(condition.container_id().is_none()); + assert!(condition.transport_kind().is_none()); + } + + #[test] + fn with_transport_kind_round_trip() { + let condition = FaultInjectionConditionBuilder::new() + .with_transport_kind(TransportKind::Gateway20) + .build(); + assert_eq!(condition.transport_kind(), Some(TransportKind::Gateway20)); } } diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/fault_injection/evaluation.rs b/sdk/cosmos/azure_data_cosmos_driver/src/fault_injection/evaluation.rs index 09f6aaed62e..5dd147e6dbc 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/fault_injection/evaluation.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/fault_injection/evaluation.rs @@ -61,6 +61,12 @@ pub enum FaultInjectionEvaluation { /// The ID of the rule. rule_id: String, }, + /// Rule was skipped because the request was not carried by the transport + /// kind that the rule restricts itself to. + TransportKindMismatch { + /// The ID of the rule. + rule_id: String, + }, /// Rule matched but was superseded by a higher-priority rule (first-match-wins). Superseded { /// The ID of the superseded rule. 
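The new `TransportKindMismatch` variant records the skip decision that `FaultClient` makes when evaluating a condition. A minimal illustration of that matching rule (the enum and function names below are made up for the example, not the driver's types): an unscoped rule matches every transport, a scoped rule matches only a client bound to exactly that transport, and an unbound (metadata) client never matches a scoped rule.

```rust
// Illustrative stand-ins for the driver's TransportKind and condition check.
#[derive(Clone, Copy, PartialEq, Debug)]
enum Kind {
    Gateway,
    Gateway20,
}

// A rule scoped to a transport kind fires only when the client is bound to
// that exact kind; clients with no kind (metadata) never match a scoped rule.
fn rule_applies(rule_scope: Option<Kind>, client_kind: Option<Kind>) -> bool {
    match rule_scope {
        None => true,                      // unscoped rule matches every transport
        Some(k) => client_kind == Some(k), // scoped rule needs an exact match
    }
}

fn main() {
    assert!(rule_applies(None, None));
    assert!(rule_applies(Some(Kind::Gateway20), Some(Kind::Gateway20)));
    assert!(!rule_applies(Some(Kind::Gateway20), Some(Kind::Gateway)));
    assert!(!rule_applies(Some(Kind::Gateway20), None)); // metadata client
}
```

This is the same truth table the `transport_kind_filter_*` tests later in this diff exercise end to end.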
@@ -87,6 +93,7 @@ impl FaultInjectionEvaluation { | Self::OperationMismatch { rule_id } | Self::RegionMismatch { rule_id } | Self::ContainerMismatch { rule_id } + | Self::TransportKindMismatch { rule_id } | Self::Superseded { rule_id } => rule_id, } } @@ -136,6 +143,9 @@ impl std::fmt::Display for FaultInjectionEvaluation { Self::ContainerMismatch { rule_id } => { write!(f, "rule '{rule_id}': skipped (container mismatch)") } + Self::TransportKindMismatch { rule_id } => { + write!(f, "rule '{rule_id}': skipped (transport kind mismatch)") + } Self::Superseded { rule_id } => { write!( f, diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/fault_injection/fault_injecting_factory.rs b/sdk/cosmos/azure_data_cosmos_driver/src/fault_injection/fault_injecting_factory.rs index ab94ac509bf..5b9153692b3 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/fault_injection/fault_injecting_factory.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/fault_injection/fault_injecting_factory.rs @@ -41,9 +41,14 @@ impl HttpClientFactory for FaultInjectingHttpClientFactory { connection_pool: &ConnectionPoolOptions, config: HttpClientConfig, ) -> azure_core::Result<Arc<dyn TransportClient>> { + let transport_kind = config.transport_kind; let real_client = self.inner.build(connection_pool, config)?; let rules = (*self.rules).clone(); - Ok(Arc::new(FaultClient::new(real_client, rules))) + Ok(Arc::new(FaultClient::new( + real_client, + rules, + transport_kind, + ))) } } diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/fault_injection/http_client.rs b/sdk/cosmos/azure_data_cosmos_driver/src/fault_injection/http_client.rs index 0a3d832309a..279da495f17 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/fault_injection/http_client.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/fault_injection/http_client.rs @@ -10,7 +10,7 @@ use super::rule::FaultInjectionRule; use super::FaultInjectionErrorType; use super::FaultInjectionEvaluation; use super::FaultOperationType; -use crate::diagnostics::RequestSentStatus; +use
crate::diagnostics::{RequestSentStatus, TransportKind}; use crate::driver::transport::cosmos_transport_client::{ HttpRequest, HttpResponse, TransportClient, TransportError, }; @@ -42,6 +42,10 @@ pub struct FaultClient { inner: Arc<dyn TransportClient>, /// The fault injection rules to apply. rules: Arc<Vec<Arc<FaultInjectionRule>>>, + /// The transport kind this client serves, when bound to a data-plane + /// transport. `None` for metadata clients (account discovery and + /// similar) where the gateway-vs-Gateway-2.0 distinction does not apply. + transport_kind: Option<TransportKind>, } impl FaultClient { @@ -49,10 +53,12 @@ pub(crate) fn new( inner: Arc<dyn TransportClient>, rules: Vec<Arc<FaultInjectionRule>>, + transport_kind: Option<TransportKind>, ) -> Self { Self { inner, rules: Arc::new(rules), + transport_kind, } } @@ -137,6 +143,18 @@ impl FaultClient { } } + if let Some(expected_kind) = condition.transport_kind() { + // The rule restricts itself to a specific transport. If this + // FaultClient is bound to a different transport (or to a + // metadata client with no transport kind at all), the rule + // does not apply.
+ if self.transport_kind != Some(expected_kind) { + return Some(FaultInjectionEvaluation::TransportKindMismatch { + rule_id: rule.id().to_owned(), + }); + } + } + None // Condition matches } @@ -377,6 +395,7 @@ impl TransportClient for FaultClient { #[cfg(test)] mod tests { use super::FaultClient; + use crate::diagnostics::TransportKind; use crate::driver::transport::cosmos_transport_client::{ HttpRequest, HttpResponse, TransportClient, TransportError, }; @@ -459,7 +478,7 @@ mod tests { .with_condition(condition) .build(); - let fault_client = FaultClient::new(mock_client.clone(), vec![Arc::new(rule)]); + let fault_client = FaultClient::new(mock_client.clone(), vec![Arc::new(rule)], None); // Request without operation type header shouldn't match let (request, _collector) = create_test_request(); @@ -472,7 +491,7 @@ mod tests { #[tokio::test] async fn execute_request_empty_rules() { let mock_client = Arc::new(MockTransportClient::new()); - let fault_client = FaultClient::new(mock_client.clone(), vec![]); + let fault_client = FaultClient::new(mock_client.clone(), vec![], None); let (request, _collector) = create_test_request(); let result = fault_client.send(&request).await; @@ -492,7 +511,7 @@ mod tests { .with_hit_limit(2) .build(); - let fault_client = FaultClient::new(mock_client.clone(), vec![Arc::new(rule)]); + let fault_client = FaultClient::new(mock_client.clone(), vec![Arc::new(rule)], None); let (request, _collector) = create_test_request(); // First two requests should hit the fault @@ -519,7 +538,7 @@ mod tests { .with_start_time(Instant::now() + Duration::from_secs(60)) .build(); - let fault_client = FaultClient::new(mock_client.clone(), vec![Arc::new(rule)]); + let fault_client = FaultClient::new(mock_client.clone(), vec![Arc::new(rule)], None); let (request, _collector) = create_test_request(); // Request should pass through because start_time is in the future @@ -537,7 +556,7 @@ mod tests { .build(); let rule = 
FaultInjectionRuleBuilder::new("error-rule", error).build(); - let fault_client = FaultClient::new(mock_client.clone(), vec![Arc::new(rule)]); + let fault_client = FaultClient::new(mock_client.clone(), vec![Arc::new(rule)], None); let (request, _collector) = create_test_request(); let result = fault_client.send(&request).await; @@ -562,7 +581,7 @@ mod tests { .build(); let rule = FaultInjectionRuleBuilder::new("throttle-rule", error).build(); - let fault_client = FaultClient::new(mock_client.clone(), vec![Arc::new(rule)]); + let fault_client = FaultClient::new(mock_client.clone(), vec![Arc::new(rule)], None); let (request, _collector) = create_test_request(); let result = fault_client.send(&request).await; @@ -586,7 +605,7 @@ mod tests { .build(); let rule = FaultInjectionRuleBuilder::new("response-delay-rule", error).build(); - let fault_client = FaultClient::new(mock_client.clone(), vec![Arc::new(rule)]); + let fault_client = FaultClient::new(mock_client.clone(), vec![Arc::new(rule)], None); let (request, _collector) = create_test_request(); // Delay-only should pass through to actual request after delay @@ -619,7 +638,7 @@ mod tests { .with_condition(condition) .build(); - let fault_client = FaultClient::new(mock_client.clone(), vec![Arc::new(rule)]); + let fault_client = FaultClient::new(mock_client.clone(), vec![Arc::new(rule)], None); // Request URL doesn't contain "westus", should pass through let (request, _collector) = create_test_request(); @@ -643,7 +662,7 @@ mod tests { .with_condition(condition) .build(); - let fault_client = FaultClient::new(mock_client.clone(), vec![Arc::new(rule)]); + let fault_client = FaultClient::new(mock_client.clone(), vec![Arc::new(rule)], None); // Request URL doesn't contain "my-container", should pass through let (request, _collector) = create_test_request(); @@ -665,7 +684,7 @@ mod tests { .with_hit_limit(2) .build(); - let fault_client = FaultClient::new(mock_client.clone(), vec![Arc::new(rule)]); + let fault_client = 
FaultClient::new(mock_client.clone(), vec![Arc::new(rule)], None); let (request, _collector) = create_test_request(); // First request should hit the fault @@ -726,7 +745,7 @@ mod tests { .build(); let rule = FaultInjectionRuleBuilder::new("substatus-rule", error).build(); - let fault_client = FaultClient::new(mock_client, vec![Arc::new(rule)]); + let fault_client = FaultClient::new(mock_client, vec![Arc::new(rule)], None); let (request, _collector) = create_test_request(); let result = fault_client.send(&request).await; @@ -783,7 +802,7 @@ mod tests { .build(); let rule = FaultInjectionRuleBuilder::new("conn-error", error).build(); - let fault_client = FaultClient::new(mock_client.clone(), vec![Arc::new(rule)]); + let fault_client = FaultClient::new(mock_client.clone(), vec![Arc::new(rule)], None); let (request, _collector) = create_test_request(); let result = fault_client.send(&request).await; @@ -807,7 +826,7 @@ mod tests { .build(); let rule = FaultInjectionRuleBuilder::new("timeout-error", error).build(); - let fault_client = FaultClient::new(mock_client.clone(), vec![Arc::new(rule)]); + let fault_client = FaultClient::new(mock_client.clone(), vec![Arc::new(rule)], None); let (request, _collector) = create_test_request(); let result = fault_client.send(&request).await; @@ -836,7 +855,7 @@ mod tests { .build(); let rule = FaultInjectionRuleBuilder::new("custom-response-rule", result).build(); - let fault_client = FaultClient::new(mock_client.clone(), vec![Arc::new(rule)]); + let fault_client = FaultClient::new(mock_client.clone(), vec![Arc::new(rule)], None); let (request, _collector) = create_test_request(); let response = fault_client.send(&request).await; @@ -862,7 +881,7 @@ mod tests { .with_condition(condition) .build(); - let fault_client = FaultClient::new(mock_client.clone(), vec![Arc::new(rule)]); + let fault_client = FaultClient::new(mock_client.clone(), vec![Arc::new(rule)], None); let (mut request, _collector) = create_test_request(); request @@ 
-890,7 +909,7 @@ mod tests { .build(); let rule = FaultInjectionRuleBuilder::new("header-test-rule", result).build(); - let fault_client = FaultClient::new(mock_client, vec![Arc::new(rule)]); + let fault_client = FaultClient::new(mock_client, vec![Arc::new(rule)], None); let (request, collector) = create_test_request(); let response = fault_client.send(&request).await; @@ -911,7 +930,7 @@ mod tests { let rule = Arc::new(FaultInjectionRuleBuilder::new("disabled-rule", error).build()); rule.disable(); - let fault_client = FaultClient::new(mock_client, vec![rule]); + let fault_client = FaultClient::new(mock_client, vec![rule], None); let (request, collector) = create_test_request(); let result = fault_client.send(&request).await; assert!(result.is_ok(), "Request should succeed with disabled rule"); @@ -937,7 +956,7 @@ mod tests { .with_error(FaultInjectionErrorType::ServiceUnavailable) .build(); let rule = FaultInjectionRuleBuilder::new("test-rule", error).build(); - let fault_client = FaultClient::new(mock_client, vec![Arc::new(rule)]); + let fault_client = FaultClient::new(mock_client, vec![Arc::new(rule)], None); let (request, collector) = create_test_request(); let _ = fault_client.send(&request).await; @@ -955,7 +974,7 @@ mod tests { .with_error(FaultInjectionErrorType::ConnectionError) .build(); let rule = FaultInjectionRuleBuilder::new("conn-rule", error).build(); - let fault_client = FaultClient::new(mock_client, vec![Arc::new(rule)]); + let fault_client = FaultClient::new(mock_client, vec![Arc::new(rule)], None); let (request, collector) = create_test_request(); let _ = fault_client.send(&request).await; @@ -973,7 +992,7 @@ mod tests { .with_error(FaultInjectionErrorType::ResponseTimeout) .build(); let rule = FaultInjectionRuleBuilder::new("timeout-rule", error).build(); - let fault_client = FaultClient::new(mock_client, vec![Arc::new(rule)]); + let fault_client = FaultClient::new(mock_client, vec![Arc::new(rule)], None); let (request, collector) = 
create_test_request(); let _ = fault_client.send(&request).await; @@ -1003,7 +1022,7 @@ mod tests { let rule2 = Arc::new(FaultInjectionRuleBuilder::new("active-rule", error2).build()); let rule3 = Arc::new(FaultInjectionRuleBuilder::new("superseded-rule", error3).build()); - let fault_client = FaultClient::new(mock_client, vec![rule1, rule2, rule3]); + let fault_client = FaultClient::new(mock_client, vec![rule1, rule2, rule3], None); let (request, collector) = create_test_request(); let _ = fault_client.send(&request).await; @@ -1039,7 +1058,7 @@ mod tests { .with_condition(condition) .build(); - let fault_client = FaultClient::new(mock_client, vec![Arc::new(rule)]); + let fault_client = FaultClient::new(mock_client, vec![Arc::new(rule)], None); // Request without matching operation header let (request, collector) = create_test_request(); @@ -1052,4 +1071,102 @@ mod tests { super::FaultInjectionEvaluation::OperationMismatch { rule_id } if rule_id == "no-match-rule" )); } + + #[tokio::test] + async fn transport_kind_filter_skips_when_kind_does_not_match() { + let mock_client = Arc::new(MockTransportClient::new()); + + // Rule scoped to Gateway 2.0 only. + let condition = FaultInjectionConditionBuilder::new() + .with_transport_kind(TransportKind::Gateway20) + .build(); + let error = FaultInjectionResultBuilder::new() + .with_error(FaultInjectionErrorType::ServiceUnavailable) + .build(); + let rule = FaultInjectionRuleBuilder::new("gw20-only", error) + .with_condition(condition) + .build(); + + // Bind the FaultClient to a non-Gateway-2.0 transport — the rule + // must be skipped and the request must reach the inner client. 
+ let fault_client = FaultClient::new( + mock_client.clone(), + vec![Arc::new(rule)], + Some(TransportKind::Gateway), + ); + + let (request, collector) = create_test_request(); + let result = fault_client.send(&request).await; + + assert!(result.is_ok()); + assert_eq!(mock_client.call_count(), 1); + + let evals = collector.take(); + assert_eq!(evals.len(), 1); + assert!(matches!( + &evals[0], + super::FaultInjectionEvaluation::TransportKindMismatch { rule_id } if rule_id == "gw20-only" + )); + } + + #[tokio::test] + async fn transport_kind_filter_applies_when_kind_matches() { + let mock_client = Arc::new(MockTransportClient::new()); + + let condition = FaultInjectionConditionBuilder::new() + .with_transport_kind(TransportKind::Gateway20) + .build(); + let error = FaultInjectionResultBuilder::new() + .with_error(FaultInjectionErrorType::ServiceUnavailable) + .build(); + let rule = FaultInjectionRuleBuilder::new("gw20-only", error) + .with_condition(condition) + .build(); + + // Bind to a Gateway 2.0 transport — the rule must apply and the + // injected error must surface to the caller. + let fault_client = FaultClient::new( + mock_client.clone(), + vec![Arc::new(rule)], + Some(TransportKind::Gateway20), + ); + + let (request, _collector) = create_test_request(); + let result = fault_client.send(&request).await; + + assert!(result.is_err()); + // Inner client must NOT have been called when a fault is injected. 
+ assert_eq!(mock_client.call_count(), 0); + } + + #[tokio::test] + async fn transport_kind_filter_skips_metadata_clients() { + let mock_client = Arc::new(MockTransportClient::new()); + + let condition = FaultInjectionConditionBuilder::new() + .with_transport_kind(TransportKind::Gateway20) + .build(); + let error = FaultInjectionResultBuilder::new() + .with_error(FaultInjectionErrorType::ServiceUnavailable) + .build(); + let rule = FaultInjectionRuleBuilder::new("gw20-only", error) + .with_condition(condition) + .build(); + + // Metadata clients have transport_kind = None. A rule that + // requires a specific transport must never apply to metadata. + let fault_client = FaultClient::new(mock_client.clone(), vec![Arc::new(rule)], None); + + let (request, collector) = create_test_request(); + let result = fault_client.send(&request).await; + + assert!(result.is_ok()); + assert_eq!(mock_client.call_count(), 1); + + let evals = collector.take(); + assert!(matches!( + evals.as_slice(), + [super::FaultInjectionEvaluation::TransportKindMismatch { rule_id }] if rule_id == "gw20-only" + )); + } } diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/fault_injection/rule.rs b/sdk/cosmos/azure_data_cosmos_driver/src/fault_injection/rule.rs index b600b008e83..45e8d2164aa 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/fault_injection/rule.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/fault_injection/rule.rs @@ -61,7 +61,11 @@ impl FaultInjectionRule { } /// Increments the hit count by one. - pub(crate) fn increment_hit_count(&self) { + /// + /// This is intended to be called by fault-injection HTTP clients (in the + /// driver and in the Cosmos SDK) when they decide to apply this rule to + /// an in-flight request, so that `hit_limit` can be honoured. 
+ pub fn increment_hit_count(&self) { self.hit_count.fetch_add(1, Ordering::SeqCst); } diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/lib.rs b/sdk/cosmos/azure_data_cosmos_driver/src/lib.rs index 978f929bdc3..15c0537d577 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/lib.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/lib.rs @@ -20,6 +20,7 @@ //! raw bytes (`&[u8]`) and return buffered responses (`Vec<u8>`). Serialization is handled by //! the consuming SDK in its native language. +pub mod constants; pub mod diagnostics; pub mod driver; #[cfg(feature = "fault_injection")] diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/models/account_reference.rs b/sdk/cosmos/azure_data_cosmos_driver/src/models/account_reference.rs index 5f32b3d9ba9..0237847e26f 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/models/account_reference.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/models/account_reference.rs @@ -34,6 +34,35 @@ impl AccountEndpoint { self.0.host_str().unwrap_or("") } + /// Returns the global database account name parsed from the endpoint hostname's first label. + /// + /// Returns `None` for emulator, IP literal, and custom-domain hosts. The parsed value is used + /// as the RNTBD `GlobalDatabaseAccountName` metadata token on Gateway 2.0 requests; when it + /// cannot be parsed, Gateway 2.0 requests fall back to standard Gateway for that account. + // Slice 3 reads this value when Gateway 2.0 dispatch is wired in. + #[allow(dead_code)] + pub(crate) fn global_database_account_name(&self) -> Option<String> { + let host = self.host(); + if host.is_empty() { + return None; + } + + if host.starts_with(|c: char| c.is_ascii_digit()) || host.contains(':') { + return None; + } + + let (label, suffix) = host.split_once('.')?; + if label.is_empty() || suffix.is_empty() { + return None; + } + + if !suffix.starts_with("documents.") { + return None; + } + + Some(label.to_owned()) + } + /// Joins a resource path to this endpoint to create a full request URL.
/// /// The path should be the resource path (e.g., "/dbs/mydb/colls/mycoll"). @@ -375,6 +404,33 @@ mod tests { assert_eq!(endpoint.host(), "myaccount.documents.azure.com"); } + #[test] + fn global_database_account_name_extracts_only_cosmos_hosts() { + let cases = [ + ("https://myaccount.documents.azure.com/", Some("myaccount")), + ( + "https://my-account-123.documents.azure.com/", + Some("my-account-123"), + ), + ("https://myacct.documents.azure.us/", Some("myacct")), + ("https://myacct.documents.azure.cn:443/", Some("myacct")), + ("https://localhost:8081/", None), + ("https://127.0.0.1:8081/", None), + ("https://[::1]:8081/", None), + ("https://my.custom.domain/", None), + ("https://example.com/", None), + ]; + + for (url, expected) in cases { + let endpoint = AccountEndpoint::try_from(url).unwrap(); + assert_eq!( + endpoint.global_database_account_name().as_deref(), + expected, + "unexpected account name for {url}" + ); + } + } + #[test] fn builder_with_master_key() { let account = diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/models/cosmos_status.rs b/sdk/cosmos/azure_data_cosmos_driver/src/models/cosmos_status.rs index 1b11373585d..6150a532521 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/models/cosmos_status.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/models/cosmos_status.rs @@ -979,6 +979,9 @@ impl SubStatusCode { /// Transport generated 503 (20003). pub const TRANSPORT_GENERATED_503: SubStatusCode = SubStatusCode(20003); + /// Client generated 400 — request wrapping failure (20400). + pub const CLIENT_GENERATED_400: SubStatusCode = SubStatusCode(20400); + /// Client generated 401 — authorization/signing failure (20401). pub const CLIENT_GENERATED_401: SubStatusCode = SubStatusCode(20401); @@ -1336,6 +1339,15 @@ impl CosmosStatus { sub_status: Some(SubStatusCode::TRANSPORT_GENERATED_503), }; + /// Client-generated 400 Bad Request (sub-status 20400). 
+ /// + /// Generated by the SDK when Gateway 2.0 request wrapping fails before + /// the request is sent. + pub const CLIENT_GENERATED_400: CosmosStatus = CosmosStatus { + status_code: StatusCode::BadRequest, + sub_status: Some(SubStatusCode::CLIENT_GENERATED_400), + }; + /// Client-generated 401 Unauthorized (sub-status 20401). /// /// Generated by the SDK when request signing/authorization fails before @@ -1814,6 +1826,8 @@ mod tests { assert_eq!(SubStatusCode::TRANSPORT_GENERATED_410.value(), 20001); assert_eq!(SubStatusCode::TIMEOUT_GENERATED_410.value(), 20002); assert_eq!(SubStatusCode::TRANSPORT_GENERATED_503.value(), 20003); + assert_eq!(SubStatusCode::CLIENT_GENERATED_400.value(), 20400); + assert_eq!(SubStatusCode::CLIENT_GENERATED_401.value(), 20401); assert_eq!(SubStatusCode::CLIENT_CPU_OVERLOAD.value(), 20004); assert_eq!(SubStatusCode::CLIENT_THREAD_STARVATION.value(), 20005); assert_eq!(SubStatusCode::CLIENT_OPERATION_TIMEOUT.value(), 20008); diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/options/connection_pool.rs b/sdk/cosmos/azure_data_cosmos_driver/src/options/connection_pool.rs index 8bc79699cfd..74e11a455ee 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/options/connection_pool.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/options/connection_pool.rs @@ -63,7 +63,7 @@ pub struct ConnectionPoolOptions { is_http2_allowed: bool, - is_gateway20_allowed: bool, + gateway20_disabled: bool, emulator_server_cert_validation: EmulatorServerCertValidation, @@ -206,13 +206,17 @@ impl ConnectionPoolOptions { self.is_http2_allowed } - /// Returns whether Gateway 2.0 feature is allowed. + /// Returns whether Gateway 2.0 is disabled for this pool. /// - /// If `true`, the driver will use Gateway 2.0 features when communicating - /// with the Cosmos DB service (if the account supports it). Gateway 2.0 - /// requires HTTP/2, so this returns `false` if HTTP/2 is disabled. 
- pub fn is_gateway20_allowed(&self) -> bool { - self.is_gateway20_allowed + /// Gateway 2.0 is enabled by default whenever the account advertises a + /// Gateway 2.0 endpoint and HTTP/2 is allowed. When this method returns + /// `true` the driver routes every request through the standard gateway + /// transport, regardless of the account advertisement. + /// + /// Gateway 2.0 also requires HTTP/2: when HTTP/2 is disabled, this method + /// returns `true` regardless of how the builder was configured. + pub fn gateway20_disabled(&self) -> bool { + self.gateway20_disabled } /// Returns the emulator server certificate validation setting. @@ -257,8 +261,12 @@ impl ConnectionPoolOptions { /// - `AZURE_COSMOS_CONNECTION_POOL_TCP_KEEPALIVE_INTERVAL_MS`: TCP keepalive probe interval in milliseconds (default: `1_000`, min: `1_000` when set) /// - `AZURE_COSMOS_CONNECTION_POOL_TCP_KEEPALIVE_RETRIES`: TCP keepalive retry count (default: none, min: `1`, max: `255`) /// - `AZURE_COSMOS_CONNECTION_POOL_IS_HTTP2_ALLOWED`: Whether HTTP/2 is allowed for gateway mode connections (default: `true`) -/// - `AZURE_COSMOS_CONNECTION_POOL_IS_GATEWAY20_ALLOWED`: Whether Gateway 2.0 feature is allowed (default: `false`) /// - `AZURE_COSMOS_EMULATOR_SERVER_CERT_VALIDATION_DISABLED`: Whether server certificate validation is disabled for emulator; `true` maps to [`EmulatorServerCertValidation::DangerousDisabled`], `false` to [`EmulatorServerCertValidation::Enabled`] (default: `false`) +/// +/// Gateway 2.0 is intentionally **not** controlled by an environment variable +/// (see `GATEWAY_20_SPEC.md` §3): the only supported disablement mechanism is +/// the [`with_gateway20_disabled`](ConnectionPoolOptionsBuilder::with_gateway20_disabled) +/// builder method. 
/// - `AZURE_COSMOS_LOCAL_ADDRESS`: Local IP address to bind to (default: none) /// /// # Example @@ -296,7 +304,7 @@ pub struct ConnectionPoolOptionsBuilder { tcp_keepalive_interval: Option<Duration>, tcp_keepalive_retries: Option<u32>, is_http2_allowed: Option<bool>, - is_gateway20_allowed: Option<bool>, + gateway20_disabled: Option<bool>, emulator_server_cert_validation: Option<EmulatorServerCertValidation>, local_address: Option<IpAddr>, } @@ -490,9 +498,25 @@ impl ConnectionPoolOptionsBuilder { self } - /// Sets whether Gateway 2.0 feature is allowed. - pub fn with_is_gateway20_allowed(mut self, value: bool) -> Self { - self.is_gateway20_allowed = Some(value); + /// Disables Gateway 2.0 for this pool. + /// + /// Gateway 2.0 is enabled by default whenever the account advertises a + /// Gateway 2.0 endpoint and HTTP/2 is allowed. Pass `true` to force every + /// request through the standard gateway transport regardless of the + /// account advertisement (operator override). + /// + /// There is intentionally no `AZURE_COSMOS_*` environment variable that + /// toggles Gateway 2.0 — the override must be applied programmatically + /// via this method. + /// + /// # Latency caveat + /// + /// Gateway 2.0 traffic flows through a Gateway 2.0 proxy that is **not + /// currently covered by the regional Cosmos DB latency SLA**. Workloads + /// with strict P99 latency requirements should call this method with + /// `true` until the proxy reaches general availability.
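+ /// + /// # Example
+ /// + /// A minimal sketch of the operator override, mirroring the builder usage
+ /// in this module's tests (illustrative; imports elided):
+ ///
+ /// ```no_run
+ /// let options = ConnectionPoolOptionsBuilder::new()
+ ///     .with_gateway20_disabled(true)
+ ///     .build()
+ ///     .unwrap();
+ ///
+ /// // HTTP/2 is still allowed by default, so Gateway 2.0 is off purely
+ /// // because of the explicit operator override.
+ /// assert!(options.gateway20_disabled());
+ /// ```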
+ pub fn with_gateway20_disabled(mut self, value: bool) -> Self { + self.gateway20_disabled = Some(value); self } @@ -532,25 +556,16 @@ impl ConnectionPoolOptionsBuilder { ValidationBounds::none(), )?; - let effective_is_gateway20_allowed = if let Some(gateway20) = self.is_gateway20_allowed { - gateway20 && effective_is_http2_allowed - } else { - match std::env::var("AZURE_COSMOS_CONNECTION_POOL_IS_GATEWAY20_ALLOWED") { - Ok(v) => { - let gateway20: bool = v.parse().map_err(|e| { - azure_core::Error::with_message( - azure_core::error::ErrorKind::DataConversion, - format!( - "Failed to parse AZURE_COSMOS_CONNECTION_POOL_IS_GATEWAY20_ALLOWED as boolean: {} ({})", - v, e - ), - ) - })?; - gateway20 && effective_is_http2_allowed - } - Err(_) => false, // TODO: Change to true before GA - } - }; + // Gateway 2.0 is enabled by default whenever HTTP/2 is allowed and + // the account advertises a Gateway 2.0 endpoint. The flag uses a + // negative-term name so that the absence of an explicit opt-out is + // the on state; operators disable Gateway 2.0 by setting this to `true`. + // There is intentionally no `AZURE_COSMOS_*` env var that toggles + // Gateway 2.0 — the override must be applied programmatically. + let explicit_disabled = self.gateway20_disabled.unwrap_or(false); + // HTTP/2 is a hard prerequisite for Gateway 2.0 — when HTTP/2 is off + // the pool is effectively gateway20-disabled regardless of the flag.
+ let effective_gateway20_disabled = explicit_disabled || !effective_is_http2_allowed; let max_connection_pool_size_default = if effective_is_http2_allowed { 1_000 @@ -765,7 +780,7 @@ impl ConnectionPoolOptionsBuilder { tcp_keepalive_interval, tcp_keepalive_retries, is_http2_allowed: effective_is_http2_allowed, - is_gateway20_allowed: effective_is_gateway20_allowed, + gateway20_disabled: effective_gateway20_disabled, emulator_server_cert_validation: match self.emulator_server_cert_validation { Some(v) => v, None => EmulatorServerCertValidation::from(parse_from_env( @@ -822,7 +837,8 @@ mod tests { Duration::from_millis(65_000) ); assert!(options.is_http2_allowed()); - assert!(!options.is_gateway20_allowed()); + // Gateway 2.0 is enabled by default whenever HTTP/2 is allowed. + assert!(!options.gateway20_disabled()); assert_eq!( options.emulator_server_cert_validation(), EmulatorServerCertValidation::Enabled @@ -879,7 +895,7 @@ mod tests { .with_tcp_keepalive_interval(Duration::from_millis(5_000)) .with_tcp_keepalive_retries(4) .with_is_http2_allowed(false) - .with_is_gateway20_allowed(true) + .with_gateway20_disabled(false) .with_emulator_server_cert_validation(EmulatorServerCertValidation::DangerousDisabled) .build() .unwrap(); @@ -942,8 +958,9 @@ mod tests { ); assert_eq!(options.tcp_keepalive_retries(), Some(4)); assert!(!options.is_http2_allowed()); - // gateway20 is set to true but HTTP/2 is false, so it should be false - assert!(!options.is_gateway20_allowed()); + // gateway20 was opted in via with_gateway20_disabled(false), but HTTP/2 is + // off, so the build forces gateway20_disabled = true. 
+ assert!(options.gateway20_disabled()); assert_eq!( options.emulator_server_cert_validation(), EmulatorServerCertValidation::DangerousDisabled @@ -1215,12 +1232,13 @@ mod tests { fn gateway20_requires_http2() { let options = ConnectionPoolOptionsBuilder::new() .with_is_http2_allowed(false) - .with_is_gateway20_allowed(true) + .with_gateway20_disabled(false) .build() .unwrap(); - // Gateway 2.0 should be disabled if HTTP/2 is not allowed - assert!(!options.is_gateway20_allowed()); + // Gateway 2.0 must be reported as disabled if HTTP/2 is not allowed, + // even when the operator explicitly opted in via with_gateway20_disabled(false). + assert!(options.gateway20_disabled()); } #[test] diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/options/mod.rs b/sdk/cosmos/azure_data_cosmos_driver/src/options/mod.rs index 559ae2dab7b..a95dce42fef 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/options/mod.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/options/mod.rs @@ -40,6 +40,7 @@ pub use policies::{ ExcludedRegions, }; pub use priority::PriorityLevel; +pub(crate) use read_consistency::resolve_effective_consistency; pub use read_consistency::ReadConsistencyStrategy; pub use region::Region; pub use throughput_control::ThroughputControlGroupOptions; diff --git a/sdk/cosmos/azure_data_cosmos_driver/src/options/read_consistency.rs b/sdk/cosmos/azure_data_cosmos_driver/src/options/read_consistency.rs index 391f92515f2..33d8f175cb2 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/src/options/read_consistency.rs +++ b/sdk/cosmos/azure_data_cosmos_driver/src/options/read_consistency.rs @@ -98,6 +98,19 @@ impl ReadConsistencyStrategy { } } +/// Resolves the effective consistency level for a read consistency strategy. 
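+///
+/// `Default` defers to the account default; every other strategy pins a fixed
+/// level. A minimal sketch of the mapping (mirrors the table test below;
+/// `ignore`d because the function is `pub(crate)`):
+///
+/// ```ignore
+/// assert_eq!(
+///     resolve_effective_consistency(
+///         ReadConsistencyStrategy::GlobalStrong,
+///         DefaultConsistencyLevel::Eventual,
+///     ),
+///     DefaultConsistencyLevel::Strong,
+/// );
+/// ```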
+pub(crate) fn resolve_effective_consistency( + strategy: ReadConsistencyStrategy, + account_default: DefaultConsistencyLevel, +) -> DefaultConsistencyLevel { + match strategy { + ReadConsistencyStrategy::Default => account_default, + ReadConsistencyStrategy::Eventual => DefaultConsistencyLevel::Eventual, + ReadConsistencyStrategy::Session => DefaultConsistencyLevel::Session, + ReadConsistencyStrategy::GlobalStrong => DefaultConsistencyLevel::Strong, + } +} + impl std::fmt::Display for ReadConsistencyStrategy { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_str(self.as_str()) @@ -204,4 +217,37 @@ mod tests { assert!(!ReadConsistencyStrategy::GlobalStrong .is_session_effective(DefaultConsistencyLevel::Session)); } + + #[test] + fn resolve_effective_consistency_table() { + let account_defaults = [ + DefaultConsistencyLevel::Strong, + DefaultConsistencyLevel::BoundedStaleness, + DefaultConsistencyLevel::Session, + DefaultConsistencyLevel::ConsistentPrefix, + DefaultConsistencyLevel::Eventual, + ]; + + for account_default in account_defaults { + assert_eq!( + resolve_effective_consistency(ReadConsistencyStrategy::Default, account_default), + account_default + ); + assert_eq!( + resolve_effective_consistency(ReadConsistencyStrategy::Eventual, account_default), + DefaultConsistencyLevel::Eventual + ); + assert_eq!( + resolve_effective_consistency(ReadConsistencyStrategy::Session, account_default), + DefaultConsistencyLevel::Session + ); + assert_eq!( + resolve_effective_consistency( + ReadConsistencyStrategy::GlobalStrong, + account_default + ), + DefaultConsistencyLevel::Strong + ); + } + } } diff --git a/sdk/cosmos/azure_data_cosmos_driver/tests/emulator_tests/driver_fault_injection.rs b/sdk/cosmos/azure_data_cosmos_driver/tests/emulator_tests/driver_fault_injection.rs index 439377273ab..e1c2770962a 100644 --- a/sdk/cosmos/azure_data_cosmos_driver/tests/emulator_tests/driver_fault_injection.rs +++ 
b/sdk/cosmos/azure_data_cosmos_driver/tests/emulator_tests/driver_fault_injection.rs @@ -6,6 +6,7 @@ #![cfg(feature = "fault_injection")] use crate::framework::DriverTestClient; +use azure_data_cosmos_driver::diagnostics::TransportKind; use azure_data_cosmos_driver::fault_injection::*; use std::error::Error; use std::sync::Arc; @@ -331,3 +332,199 @@ pub async fn fault_injection_connection_error() -> Result<(), Box<dyn Error>> { }) .await } + +// ---------------------------------------------------------------------------- +// Gateway 2.0 fault injection coverage (Phase 6) +// ---------------------------------------------------------------------------- +// +// The following three tests lock in the retry/failover behavior the Gateway +// 2.0 transport must exhibit when the underlying Gateway 2.0 connection fails. +// Each test exercises a distinct failure shape: +// +// - 503 Service Unavailable → regional failover +// - 408 Request Timeout → cross-region for reads / local-only for writes +// - 404/1002 Read Session → remote-preferred + no PKRange refresh +// +// **Limitation**: each rule below is scoped to Gateway 2.0 via +// `with_transport_kind(TransportKind::Gateway20)`, but the emulator does not +// yet expose Gateway 2.0 endpoints, so there is no Gateway 2.0 traffic for +// these rules to fire on locally. To reliably exercise these against Gateway +// 2.0, the Phase 6 CI matrix must run them on a live Gateway 2.0 account +// (`test_category = "gateway20"`). See `docs/GATEWAY_20_SPEC.md` (Phase 6) for the harness gap. + +/// Gateway 2.0 503 Service Unavailable should trigger regional failover. +/// +/// The rule is scoped to [`TransportKind::Gateway20`] so it does not also +/// fire on standard-gateway requests issued during account discovery.
The +/// emulator does not yet expose Gateway 2.0 endpoints, so this test is +/// gated behind the `gateway20` test category until CI gains a Gateway 2.0 +/// account; see `docs/GATEWAY_20_SPEC.md` (Phase 6). +#[tokio::test] +#[cfg_attr( + not(test_category = "gateway20"), + ignore = "requires test_category 'gateway20'" +)] +pub async fn gateway20_service_unavailable_triggers_regional_failover() -> Result<(), Box<dyn Error>> +{ + let condition = FaultInjectionConditionBuilder::new() + .with_operation_type(FaultOperationType::ReadItem) + .with_transport_kind(TransportKind::Gateway20) + .build(); + + let result = FaultInjectionResultBuilder::new() + .with_error(FaultInjectionErrorType::ServiceUnavailable) + .with_probability(1.0) + .build(); + + let rule = Arc::new( + FaultInjectionRuleBuilder::new("gateway20-503-failover", result) + .with_condition(condition) + .build(), + ); + let rules = vec![Arc::clone(&rule)]; + + DriverTestClient::run_with_unique_db_and_fault_injection(rules, async |context, database| { + let container_name = context.unique_container_name(); + let container = context + .create_container(&database, &container_name, "/pk") + .await?; + + let item_json = br#"{"id": "item1", "pk": "pk1", "value": "test"}"#; + context + .create_item(&container, "item1", "pk1", item_json) + .await?; + + // The read should fail (single region, fault always fires) but the + // failover machinery must have been invoked. Once `RequestDiagnostics` + // exposes per-attempt endpoint selection, assert that the diagnostics + // record at least one regional failover attempt.
+ let read_result = context.read_item(&container, "item1", "pk1").await; + assert!( + read_result.is_err(), + "Read should fail when 503 fires on every attempt" + ); + + assert!(rule.hit_count() > 0, "Rule should have been hit"); + + Ok(()) + }) + .await +} + +/// Gateway 2.0 408 Request Timeout should retry across regions for reads, +/// but stay local-only for writes (single-region writes can't safely retry +/// across regions without risking duplicates). +/// +/// The rule is scoped to [`TransportKind::Gateway20`] so it does not affect +/// standard-gateway traffic. The emulator does not yet expose Gateway 2.0 +/// endpoints, so this test is gated behind the `gateway20` test category. +#[tokio::test] +#[cfg_attr( + not(test_category = "gateway20"), + ignore = "requires test_category 'gateway20'" +)] +pub async fn gateway20_request_timeout_cross_region_for_reads() -> Result<(), Box<dyn Error>> { + let condition = FaultInjectionConditionBuilder::new() + .with_operation_type(FaultOperationType::ReadItem) + .with_transport_kind(TransportKind::Gateway20) + .build(); + + let result = FaultInjectionResultBuilder::new() + .with_error(FaultInjectionErrorType::Timeout) + .with_probability(1.0) + .build(); + + let rule = Arc::new( + FaultInjectionRuleBuilder::new("gateway20-408-cross-region", result) + .with_condition(condition) + .build(), + ); + let rules = vec![Arc::clone(&rule)]; + + DriverTestClient::run_with_unique_db_and_fault_injection(rules, async |context, database| { + let container_name = context.unique_container_name(); + let container = context + .create_container(&database, &container_name, "/pk") + .await?; + + let item_json = br#"{"id": "item1", "pk": "pk1", "value": "test"}"#; + context + .create_item(&container, "item1", "pk1", item_json) + .await?; + + let read_result = context.read_item(&container, "item1", "pk1").await; + assert!( + read_result.is_err(), + "Read should ultimately fail when 408 fires on every attempt" + ); + + // TODO(Phase 6): once diagnostics
expose retry attempts, assert that + // a single-region account exhausts local-only retries while a + // multi-region account performs at least one cross-region attempt. + assert!(rule.hit_count() > 0, "Rule should have been hit"); + + Ok(()) + }) + .await +} + +/// Gateway 2.0 404/1002 ReadSessionNotAvailable must trigger a +/// remote-preferred retry path **without** invalidating the partition-key +/// range (PKRange) cache. The 404/1002 substatus indicates a session-token +/// mismatch, which is unrelated to the routing topology — refreshing PKRange +/// would be a wasted metadata round-trip. +/// +/// The rule is scoped to [`TransportKind::Gateway20`] so it does not also +/// fire on standard-gateway requests. The emulator does not yet expose +/// Gateway 2.0 endpoints, so this test is gated behind the `gateway20` +/// test category until CI gains a Gateway 2.0 account. +#[tokio::test] +#[cfg_attr( + not(test_category = "gateway20"), + ignore = "requires test_category 'gateway20'" +)] +pub async fn gateway20_read_session_not_available_remote_preferred() -> Result<(), Box<dyn Error>> { + let condition = FaultInjectionConditionBuilder::new() + .with_operation_type(FaultOperationType::ReadItem) + .with_transport_kind(TransportKind::Gateway20) + .build(); + + let result = FaultInjectionResultBuilder::new() + .with_error(FaultInjectionErrorType::ReadSessionNotAvailable) + .with_probability(1.0) + .build(); + + let rule = Arc::new( + FaultInjectionRuleBuilder::new("gateway20-1002-remote-preferred", result) + .with_condition(condition) + .build(), + ); + let rules = vec![Arc::clone(&rule)]; + + DriverTestClient::run_with_unique_db_and_fault_injection(rules, async |context, database| { + let container_name = context.unique_container_name(); + let container = context + .create_container(&database, &container_name, "/pk") + .await?; + + let item_json = br#"{"id": "item1", "pk": "pk1", "value": "test"}"#; + context + .create_item(&container, "item1", "pk1", item_json) + .await?; + +
let read_result = context.read_item(&container, "item1", "pk1").await; + assert!( + read_result.is_err(), + "Read should fail when 404/1002 fires on every attempt" + ); + + // TODO(Phase 6): once diagnostics record metadata-cache hits, assert + // that the PKRange cache was NOT refreshed during these retries (a + // 404/1002 is a session-token issue, not a routing-topology issue). + assert!(rule.hit_count() > 0, "Rule should have been hit"); + + Ok(()) + }) + .await +} diff --git a/sdk/cosmos/azure_data_cosmos_driver/tests/gateway20_pipeline_tests.rs b/sdk/cosmos/azure_data_cosmos_driver/tests/gateway20_pipeline_tests.rs new file mode 100644 index 00000000000..0d5c66f15f0 --- /dev/null +++ b/sdk/cosmos/azure_data_cosmos_driver/tests/gateway20_pipeline_tests.rs @@ -0,0 +1,445 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//! Integration tests that lock in the Gateway 2.0 transport pipeline contract. +//! +//! These tests cover Phase 6 of the Gateway 2.0 specification (see +//! `docs/GATEWAY_20_SPEC.md`). They run as a standalone integration target so +//! they exercise the public surface of the driver crate end-to-end (no +//! `pub(crate)` access). +//! +//! ## Categories +//! +//! 1. **Operator override** — the operator can opt out of Gateway 2.0 even when +//! the account advertises a Gateway 2.0 endpoint. Verified via the public +//! [`ConnectionPoolOptionsBuilder::with_gateway20_disabled`] toggle. +//! +//! 2. **Operation eligibility** — operations that Gateway 2.0 does not yet +//! support (e.g., stored procedure execution) must transparently fall back +//! to the standard gateway. Documented as an env-gated stub today; the +//! inside-crate routing tests in `operation_pipeline.rs` cover the +//! decision logic. +//! +//! 3. **Diagnostics fidelity** — `RequestDiagnostics` records the actual +//! `TransportKind` used. Documented as an env-gated stub today. +//! +//! 4.
**Dual-consistency invariants (V1)** — the V1 HTTP path must never emit +//! *both* the legacy `x-ms-consistency-level` and the newer +//! `x-ms-cosmos-read-consistency-strategy` headers. Asserted via captured +//! HTTP requests through the `__internal_mocking` factory. +//! +//! 5. **Dual-consistency invariants (V2)** — the V2 RNTBD path must never +//! serialize *both* `ConsistencyLevel` and a separate +//! `ReadConsistencyStrategy` token. Documented as an invariant lock; the +//! underlying RNTBD enum currently exposes only the `ConsistencyLevel` +//! token (`tokens.rs`), so the invariant is structurally guaranteed. +//! +//! 6. **Capabilities header pin** — every outgoing request carries +//! `x-ms-cosmos-sdk-supportedcapabilities = "9"`. Asserted via the first +//! captured request through the mock factory. +//! +//! ## Why `__internal_mocking`? +//! +//! Several of these contracts can only be observed at the network boundary. +//! The driver exposes a [`HttpClientFactory`] override under the +//! `__internal_mocking` feature flag specifically for tests like these — it +//! lets us substitute a capturing transport so we can inspect the very first +//! request the runtime emits (the account-properties probe), without ever +//! touching the network. 
+ +#![cfg(feature = "__internal_mocking")] + +use std::sync::{Arc, Mutex}; + +use async_trait::async_trait; +use azure_data_cosmos_driver::models::{AccountReference, CosmosOperation, DatabaseReference}; +use azure_data_cosmos_driver::options::{DriverOptions, OperationOptions}; +use azure_data_cosmos_driver::testing::{ + ConnectionPoolOptions, HttpClientConfig, HttpClientFactory, HttpRequest, HttpResponse, + TransportClient, TransportError, +}; +use azure_data_cosmos_driver::CosmosDriverRuntime; +use url::Url; + +// ---------------------------------------------------------------------------- +// Capturing transport +// ---------------------------------------------------------------------------- + +/// Records every outgoing request. By default, every send returns a +/// connection-style failure so the runtime aborts before the second hop, which +/// keeps the test focused on the first wire frame. +#[derive(Debug, Default)] +struct CapturingTransport { + requests: Mutex<Vec<HttpRequest>>, +} + +impl CapturingTransport { + fn requests(&self) -> Vec<HttpRequest> { + self.requests + .lock() + .expect("poisoned capture mutex") + .clone() + } +} + +#[async_trait] +impl TransportClient for CapturingTransport { + async fn send(&self, request: &HttpRequest) -> Result<HttpResponse, TransportError> { + self.requests + .lock() + .expect("poisoned capture mutex") + .push(request.clone()); + + Err(TransportError::new( + azure_core::Error::with_message( + azure_core::error::ErrorKind::Other, + "capturing transport refuses every request", + ), + azure_data_cosmos_driver::diagnostics::RequestSentStatus::NotSent, + )) + } +} + +#[derive(Debug)] +struct CapturingFactory { + transport: Arc<CapturingTransport>, +} + +impl CapturingFactory { + fn new() -> (Self, Arc<CapturingTransport>) { + let transport = Arc::new(CapturingTransport::default()); + ( + Self { + transport: transport.clone(), + }, + transport, + ) + } +} + +impl HttpClientFactory for CapturingFactory { + fn build( + &self, + _connection_pool: &ConnectionPoolOptions, + _config: HttpClientConfig, + ) -> azure_core::Result<Arc<dyn TransportClient>>
{ + Ok(self.transport.clone() as Arc<dyn TransportClient>) + } +} + +// ---------------------------------------------------------------------------- +// Helpers +// ---------------------------------------------------------------------------- + +fn fake_account() -> AccountReference { + let url = + Url::parse("https://gw20-pipeline-tests.documents.azure.com/").expect("static URL parses"); + // Master-key value is base64-encoded; the bytes never reach the wire because + // the capturing transport short-circuits every send. + AccountReference::with_master_key(url, "dGVzdC1tYXN0ZXIta2V5") +} + +fn read_env(name: &str) -> Option<String> { + std::env::var(name).ok().filter(|v| !v.trim().is_empty()) +} + +fn live_account_from_env() -> Option<AccountReference> { + let endpoint = read_env("AZURE_COSMOS_GW20_ENDPOINT")?; + let key = read_env("AZURE_COSMOS_GW20_KEY")?; + let url = Url::parse(&endpoint).ok()?; + Some(AccountReference::with_master_key(url, key)) +} + +/// Builds a runtime with the capturing factory and the requested +/// Gateway 2.0 toggle. The flag reflects the operator override exposed via +/// `ConnectionPoolOptions` — passing `true` forces every request through the +/// standard gateway transport. +async fn capturing_runtime( + gateway20_disabled: bool, +) -> (Arc<CosmosDriverRuntime>, Arc<CapturingTransport>) { + let (factory, transport) = CapturingFactory::new(); + let pool = ConnectionPoolOptions::builder() + .with_gateway20_disabled(gateway20_disabled) + .build() + .expect("connection pool builds"); + let runtime = CosmosDriverRuntime::builder() + .with_connection_pool(pool) + .with_mock_http_client_factory(Arc::new(factory)) + .build() + .await + .expect("runtime builds with mock factory"); + (runtime, transport) +} + +/// Drive a no-op probe so the runtime emits at least one HTTP request. +/// +/// The capturing transport refuses every send, so this always returns an +/// error. We only care about the captured frames.
+async fn probe(runtime: &Arc<CosmosDriverRuntime>) { + let account = fake_account(); + let options = DriverOptions::builder(account.clone()).build(); + let _ = runtime.get_or_create_driver(account, Some(options)).await; +} + +// ---------------------------------------------------------------------------- +// (a) Operator override forces standard gateway routing +// ---------------------------------------------------------------------------- + +/// Verifies that the operator override flag (`with_gateway20_disabled(true)`) +/// is honored end-to-end at the connection-pool level. When the flag is set, +/// the runtime must not select the Gateway 2.0 transport even if account +/// metadata advertises a Gateway 2.0 endpoint. +/// +/// We assert the contract structurally via `ConnectionPoolOptions`: when the +/// flag is `true`, `gateway20_disabled()` reports `true`, and the +/// transport-layer dispatcher branches to the standard gateway (this branching +/// is covered by the inside-crate tests in +/// `driver::transport::tests::dataplane_transport_*`). +#[tokio::test] +async fn operator_override_disables_gateway20_at_pool_level() { + let off = ConnectionPoolOptions::builder() + .with_gateway20_disabled(true) + .build() + .expect("pool builds"); + assert!( + off.gateway20_disabled(), + "operator-disabled pool must report gateway20_disabled = true" + ); + + let on = ConnectionPoolOptions::builder() + .with_gateway20_disabled(false) + .build() + .expect("pool builds"); + assert!( + !on.gateway20_disabled(), + "operator-enabled pool must report gateway20_disabled = false" + ); +} + +/// Live-account companion to the above. Drives a real read against a +/// pre-provisioned Gateway 2.0 account with the operator override turned on +/// (Gateway 2.0 disabled), then asserts (TODO once diagnostics expose +/// `TransportKind`) that the request used the standard gateway transport.
+#[tokio::test] +#[ignore = "Requires AZURE_COSMOS_GW20_ENDPOINT/_KEY to a Gateway 2.0 account"] +async fn operator_override_routes_reads_to_standard_gateway() { + let Some(account) = live_account_from_env() else { + return; + }; + + // TODO(Phase 6): once diagnostics expose `TransportKind` per request, + // assert that every request used `TransportKind::StandardGateway`. + let pool = ConnectionPoolOptions::builder() + .with_gateway20_disabled(true) + .build() + .expect("pool builds"); + let runtime = CosmosDriverRuntime::builder() + .with_connection_pool(pool) + .build() + .await + .expect("runtime builds"); + let driver = runtime + .get_or_create_driver(account.clone(), None) + .await + .expect("driver init succeeds against the live account"); + + let db = read_env("AZURE_COSMOS_GW20_DATABASE").unwrap_or_else(|| "gw20-tests".to_string()); + let db_ref = DatabaseReference::from_name(driver.account().clone(), db); + + let _ = driver + .execute_operation( + CosmosOperation::read_database(db_ref), + OperationOptions::default(), + ) + .await; +} + +// ---------------------------------------------------------------------------- +// (b) Operation eligibility fallback (StoredProc Execute → standard gateway) +// ---------------------------------------------------------------------------- + +/// Stored procedure execution is not yet supported by Gateway 2.0 and must +/// fall back to the standard gateway transparently. +/// +/// The eligibility decision is made in `resolve_endpoint` +/// (operation_pipeline.rs); the inside-crate tests in +/// `driver::pipeline::operation_pipeline::tests::resolve_endpoint_*` cover the +/// matrix exhaustively. This standalone test is the live-account contract +/// lock — once `TransportKind` is exposed in diagnostics, assert that the +/// stored-procedure-execute request used `TransportKind::StandardGateway` +/// while a co-located point read on the same account used +/// `TransportKind::Gateway20`. 
+#[tokio::test] +#[ignore = "Requires AZURE_COSMOS_GW20_ENDPOINT/_KEY plus a stored procedure resource"] +async fn stored_proc_execute_falls_back_to_standard_gateway() { + let Some(_account) = live_account_from_env() else { + return; + }; + // TODO(Phase 6): drive `CosmosOperation::execute_stored_procedure(...)` + // against a real account and assert the diagnostics record + // `TransportKind::StandardGateway` for that request specifically while + // co-located point reads/writes record `TransportKind::Gateway20`. +} + +// ---------------------------------------------------------------------------- +// (c) Diagnostics records TransportKind::Gateway20 +// ---------------------------------------------------------------------------- + +/// Once Gateway 2.0 has dispatched a request, the recorded +/// `RequestDiagnostics` for that request must indicate `TransportKind::Gateway20`. +/// +/// This contract requires a live Gateway 2.0 account. The inside-crate test +/// `transport_pipeline::tests::gateway20_pipeline_records_transport_kind` +/// already covers the wiring at the unit-test level; this standalone test is +/// the live-account companion. +#[tokio::test] +#[ignore = "Requires AZURE_COSMOS_GW20_ENDPOINT/_KEY to a Gateway 2.0 account"] +async fn diagnostics_records_gateway20_transport_kind() { + let Some(_account) = live_account_from_env() else { + return; + }; + // TODO(Phase 6): once `TransportKind` is exposed on the public + // `RequestDiagnostics`, drive a point read against the live Gateway 2.0 + // account and assert the diagnostics report `TransportKind::Gateway20`. +} + +// ---------------------------------------------------------------------------- +// (d) V1 HTTP dual-consistency-header invariant +// ---------------------------------------------------------------------------- + +/// The V1 HTTP path must never emit *both* the legacy +/// `x-ms-consistency-level` header and the newer +/// `x-ms-cosmos-read-consistency-strategy` header on the same request. 
+/// +/// Today the V1 path emits *neither* header (consistency is propagated via +/// the operation context, not a wire header), so the invariant trivially +/// holds. We capture the first wire frame the runtime emits and assert the +/// pair-presence rule. +#[tokio::test] +async fn v1_http_never_emits_both_consistency_headers() { + const LEGACY: &str = "x-ms-consistency-level"; + const STRATEGY: &str = "x-ms-cosmos-read-consistency-strategy"; + + let (runtime, transport) = capturing_runtime(true).await; + probe(&runtime).await; + + let captured = transport.requests(); + for req in &captured { + let has_legacy = req.headers.iter().any(|(name, _)| name.as_str() == LEGACY); + let has_strategy = req + .headers + .iter() + .any(|(name, _)| name.as_str() == STRATEGY); + assert!( + !(has_legacy && has_strategy), + "request {:?} emitted both '{LEGACY}' and '{STRATEGY}' — V1 invariant violated", + req.url + ); + } +} + +// ---------------------------------------------------------------------------- +// (e) V2 RNTBD dual-consistency-token invariant +// ---------------------------------------------------------------------------- + +/// The V2 (RNTBD) path must never serialize *both* a `ConsistencyLevel` token +/// and a separate `ReadConsistencyStrategy` token on the same wrapped frame. +/// +/// Today the RNTBD token enum +/// (`driver::transport::rntbd::tokens::RntbdRequestToken`) exposes only the +/// `ConsistencyLevel` variant — there is no `ReadConsistencyStrategy` token +/// at all — so the invariant is structurally guaranteed by the type system. +/// This test is therefore a *contract lock* expressed at the boundary this +/// integration test can actually observe. +/// +/// `CapturingTransport` lives at the `HttpClientFactory` layer, so it only +/// ever sees V1 HTTP requests (account-properties probe, metadata reads, +/// etc.). RNTBD frames are dispatched via a separate TCP transport and are +/// invisible here. We assert two things: +/// +/// 1. 
The capturing transport actually recorded at least one request — i.e. +/// the test setup is wired correctly and the runtime did make outbound +/// progress. +/// 2. Every captured request uses an `http`/`https` scheme. If a future +/// change ever tunnels wrapped RNTBD frames through HTTP (or pushes the +/// capture point lower in the stack so RNTBD is observable here), this +/// assertion fires and forces a reviewer to upgrade the test to parse +/// the wrapped frame and assert at-most-one consistency token per frame. +/// +/// The structural invariant inside the wrapped frame is exhaustively covered +/// by the inside-crate tests in `gateway20_dispatch::tests::wraps_with_*`; +/// this test exists to prevent that coverage from silently disappearing if +/// the V2 transport boundary moves. +#[tokio::test] +async fn v2_rntbd_never_emits_both_consistency_tokens() { + let (runtime, transport) = capturing_runtime(false).await; + probe(&runtime).await; + + let captured = transport.requests(); + assert!( + !captured.is_empty(), + "capturing transport recorded zero requests; the V2 invariant test \ + setup is broken (no traffic was generated at all)" + ); + + // CONTRACT LOCK: today every captured request is a V1 HTTP probe by + // construction. If this assertion ever fails, RNTBD-bearing traffic has + // become observable at the HttpClientFactory layer and the body must be + // structurally decoded to assert mutual exclusion of `ConsistencyLevel` + // and any future `ReadConsistencyStrategy` token. + // + // TODO(Phase 6): when a `ReadConsistencyStrategy` RNTBD token lands, + // replace this scheme check with a structural decode of the wrapped + // frame and assert at-most-one consistency token per wrapped request. 
+    for req in &captured {
+        let scheme = req.url.scheme();
+        assert!(
+            scheme == "http" || scheme == "https",
+            "captured request to {} uses scheme {:?}; the V2 dual-token \
+             contract lock is invalidated — upgrade this test to parse the \
+             wrapped RNTBD frame and assert mutual exclusion of consistency \
+             tokens",
+            req.url,
+            scheme,
+        );
+    }
+}
+
+// ----------------------------------------------------------------------------
+// (f) Capabilities header pin
+// ----------------------------------------------------------------------------
+
+/// Every outgoing HTTP request must carry
+/// `x-ms-cosmos-sdk-supportedcapabilities: 9`. The bitmask "9" is the
+/// bitwise OR of `PartitionMerge` (1) and `IgnoreUnknownRntbdTokens` (8),
+/// which Gateway 2.0 inspects to decide whether the SDK can tolerate unknown
+/// RNTBD tokens.
+///
+/// This is the load-bearing forward-compatibility advertisement for Gateway
+/// 2.0 — it MUST stay pinned to "9" until both bits are coordinated with a
+/// service-side rollout.
+#[tokio::test]
+async fn capabilities_header_value_is_pinned_to_nine() {
+    const CAPABILITIES: &str = "x-ms-cosmos-sdk-supportedcapabilities";
+
+    let (runtime, transport) = capturing_runtime(false).await;
+    probe(&runtime).await;
+
+    let captured = transport.requests();
+    assert!(
+        !captured.is_empty(),
+        "runtime should have emitted at least one request via the mock factory"
+    );
+
+    for req in &captured {
+        let value = req.headers.iter().find_map(|(name, value)| {
+            (name.as_str() == CAPABILITIES).then(|| value.as_str().to_owned())
+        });
+        assert_eq!(
+            value.as_deref(),
+            Some("9"),
+            "capabilities header missing or wrong on request to {}",
+            req.url
+        );
+    }
+}
diff --git a/sdk/cosmos/azure_data_cosmos_perf/src/main.rs b/sdk/cosmos/azure_data_cosmos_perf/src/main.rs
index 719f52d4cb2..eb4a284c29d 100644
--- a/sdk/cosmos/azure_data_cosmos_perf/src/main.rs
+++ b/sdk/cosmos/azure_data_cosmos_perf/src/main.rs
@@ -282,10 +282,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
             .ok()
             .and_then(|v| v.parse::<bool>().ok())
             .unwrap_or(true),
-        gateway20_allowed: std::env::var("AZURE_COSMOS_CONNECTION_POOL_IS_GATEWAY20_ALLOWED")
-            .ok()
-            .and_then(|v| v.parse::<bool>().ok())
-            .unwrap_or(false),
+        // Gateway 2.0 is intentionally not toggled via env var (see
+        // GATEWAY_20_SPEC.md §3). Until the perf binary wires through the
+        // public SDK toggle (`CosmosClientBuilder::with_gateway20_disabled`),
+        // this snapshot field is pinned to `true` and does not track the
+        // SDK's actual default (Gateway 2.0 is now enabled by default).
+        // TODO: Read the actual configured value from the SDK once the
+        // public toggle lands.
+        gateway20_disabled: true,
         pyroscope_enabled: std::env::var("PYROSCOPE_SERVER_URL")
             .map(|v| !v.is_empty())
             .unwrap_or(false),
diff --git a/sdk/cosmos/azure_data_cosmos_perf/src/runner.rs b/sdk/cosmos/azure_data_cosmos_perf/src/runner.rs
index 1d24329782e..cfb197a1ca9 100644
--- a/sdk/cosmos/azure_data_cosmos_perf/src/runner.rs
+++ b/sdk/cosmos/azure_data_cosmos_perf/src/runner.rs
@@ -77,7 +77,7 @@ struct PerfResult {
     #[serde(skip_serializing_if = "Option::is_none")]
     config_ppcb_enabled: Option<bool>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    config_gateway20_allowed: Option<bool>,
+    config_gateway20_disabled: Option<bool>,
     #[serde(skip_serializing_if = "Option::is_none")]
     config_pyroscope_enabled: Option<bool>,
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -135,7 +135,7 @@ pub struct ConfigSnapshot {
     pub excluded_regions: String,
     pub tokio_threads: u64,
     pub ppcb_enabled: bool,
-    pub gateway20_allowed: bool,
+    pub gateway20_disabled: bool,
     pub pyroscope_enabled: bool,
     pub tokio_console_enabled: bool,
     pub tokio_metrics_enabled: bool,
@@ -392,7 +392,7 @@ async fn upsert_results(
         },
         config_tokio_threads: Some(config.tokio_threads),
         config_ppcb_enabled: Some(config.ppcb_enabled),
-        config_gateway20_allowed: Some(config.gateway20_allowed),
+        config_gateway20_disabled: Some(config.gateway20_disabled),
         config_pyroscope_enabled: Some(config.pyroscope_enabled),
         config_tokio_console_enabled: Some(config.tokio_console_enabled),
         config_tokio_metrics_enabled: Some(config.tokio_metrics_enabled),
diff --git a/sdk/cosmos/ci.yml b/sdk/cosmos/ci.yml
index e12fafc594b..4fa9d99cbbe 100644
--- a/sdk/cosmos/ci.yml
+++ b/sdk/cosmos/ci.yml
@@ -43,6 +43,21 @@ extends:
     CloudConfig:
       Public:
         ServiceConnection: azure-sdk-tests-cosmos
+        # Endpoint + master key for the pre-provisioned Gateway 2.0 account,
+        # surfaced from the `azure-sdk-tests-cosmos` service connection's
+        # secret variable group. Both `LiveTestMatrixConfigs` entries below see
+        # these env vars at job time; the standard Cosmos live tests ignore them
+        # while the Gateway 2.0 matrix entry consumes them via the `gateway20`
+        # test-category scaffolding (see
+        # `azure_data_cosmos/tests/emulator_tests/gateway20_e2e.rs`).
+        #
+        # This mirrors the Java Cosmos SDK's Gateway 2.0 live-test
+        # setup (`sdk/cosmos/tests.yml` in `Azure/azure-sdk-for-java`), which
+        # adds a second matrix entry pointing at a pre-provisioned Gateway 2.0
+        # account rather than spinning up a dedicated pipeline.
+        EnvVars:
+          AZURE_COSMOS_GW20_ENDPOINT: $(gateway20-test-endpoint)
+          AZURE_COSMOS_GW20_KEY: $(gateway20-test-key)
     MatrixConfigs:
       - Name: Cosmos_release
         Path: sdk/cosmos/release-platform-matrix.json
         Selection: sparse
         GenerateVMJobs: true
@@ -54,3 +69,13 @@ extends:
         Path: sdk/cosmos/live-platform-matrix.json
         Selection: sparse
         GenerateVMJobs: true
+      # Gateway 2.0 live tests run against a pre-provisioned account that is
+      # NOT created per-pipeline-run. The `ArmTemplateParameters` block in
+      # `live-gateway20-matrix.json` is kept so the deploy step still fires
+      # (the matrix machinery requires it), but the provisioned account is
+      # unused — tests connect to the dedicated Gateway 2.0 account via the
+      # `AZURE_COSMOS_GW20_ENDPOINT/_KEY` env vars wired above.
+      - Name: Cosmos_gateway20_live_test
+        Path: sdk/cosmos/live-gateway20-matrix.json
+        Selection: sparse
+        GenerateVMJobs: true
diff --git a/sdk/cosmos/live-gateway20-matrix.json b/sdk/cosmos/live-gateway20-matrix.json
new file mode 100644
index 00000000000..4075306ca7b
--- /dev/null
+++ b/sdk/cosmos/live-gateway20-matrix.json
@@ -0,0 +1,21 @@
+{
+  "displayNames": {},
+  "matrix": {
+    "Agent": {
+      "ubuntu": {
+        "OSVmImage": "env:LINUXVMIMAGE",
+        "Pool": "env:LINUXPOOL"
+      }
+    },
+    "RustToolchainName": ["stable"],
+    "Gateway20 Settings": {
+      "Session SingleRegion": {
+        "ArmTemplateParameters": "@{ defaultConsistencyLevel = 'Session'; testCategory = 'gateway20' }"
+      },
+      "Session MultiRegion": {
+        "ArmTemplateParameters": "@{ defaultConsistencyLevel = 'Session'; enableMultipleRegions = $true; testCategory = 'gateway20_multi_region' }"
+      }
+    }
+  },
+  "include": []
+}
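As a side note on the capabilities pin above: the header value "9" is the bitwise OR of the two advertised capability bits. A minimal standalone sketch of that arithmetic follows; the constant names mirror the test's doc comment, but the actual SDK constants and module layout are not part of this diff and may differ.

```rust
// Sketch of the capability bitmask behind the pinned
// `x-ms-cosmos-sdk-supportedcapabilities: 9` header. The flag names are
// illustrative (taken from the test's doc comment), not real SDK items.

/// Bit 0: the SDK understands partition-merge responses.
const PARTITION_MERGE: u64 = 1 << 0; // = 1

/// Bit 3: the SDK tolerates unknown RNTBD tokens from Gateway 2.0.
const IGNORE_UNKNOWN_RNTBD_TOKENS: u64 = 1 << 3; // = 8

/// Render the bitmask the way it is stamped into the header.
fn supported_capabilities_header_value() -> String {
    (PARTITION_MERGE | IGNORE_UNKNOWN_RNTBD_TOKENS).to_string()
}

fn main() {
    // 1 | 8 == 9, matching the value the integration test pins.
    assert_eq!(supported_capabilities_header_value(), "9");
    println!("{}", supported_capabilities_header_value());
}
```

If either bit were dropped or a new one added, the rendered value would change, which is exactly what the `capabilities_header_value_is_pinned_to_nine` test is designed to catch.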
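On the test side, the `AZURE_COSMOS_GW20_ENDPOINT`/`AZURE_COSMOS_GW20_KEY` variables wired in `ci.yml` have to be resolved at runtime. A hedged sketch of how a `gateway20` test might do that — the helper names and the lookup-injection shape are hypothetical, not the actual test scaffolding in `gateway20_e2e.rs`:

```rust
use std::env;

/// Hypothetical resolver for the dedicated Gateway 2.0 account. The env-var
/// lookup is injected as a closure so the logic is testable without mutating
/// the process environment.
fn gw20_account_from(
    lookup: impl Fn(&str) -> Option<String>,
) -> Option<(String, String)> {
    // Treat missing OR empty values as "not configured" so local runs
    // (where ci.yml has not exported anything) skip cleanly.
    let endpoint = lookup("AZURE_COSMOS_GW20_ENDPOINT").filter(|v| !v.is_empty())?;
    let key = lookup("AZURE_COSMOS_GW20_KEY").filter(|v| !v.is_empty())?;
    Some((endpoint, key))
}

/// Production shape: read the real process environment.
fn gw20_account() -> Option<(String, String)> {
    gw20_account_from(|name| env::var(name).ok())
}

fn main() {
    match gw20_account() {
        Some((endpoint, _key)) => println!("gateway20 tests target {endpoint}"),
        None => println!("GW20 env vars not set; gateway20 tests would be skipped"),
    }
}
```

Returning `Option` (rather than panicking) fits the matrix design above: the standard live-test entry simply never sees the account, while the `gateway20` entry gets both variables from the secret variable group.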