diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index 8dd5d718b51a..93e7f91b3d5b 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -1,5 +1,10 @@ ## Release History +### 4.14.7 (Unreleased) + +#### Bugs Fixed +* Fixed bug where region names in `preferred_locations` and `excluded_locations` (client-level and per-request) were matched case-sensitively and required exact spacing. See [PR 46792](https://github.com/Azure/azure-sdk-for-python/pull/46792) + ### 4.14.6 (2026-02-02) #### Bugs Fixed diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index cf8239488712..9bf4bb6a9502 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -36,6 +36,13 @@ logger = logging.getLogger("azure.cosmos.LocationCache") + +def _normalize_region_name(region_name: str | None) -> str: + if region_name is None: + return "" + normalized = "".join(str(region_name).strip().lower().split()) + return normalized.replace("-", "").replace("_", "") + class EndpointOperationType(object): NoneType = "None" ReadType = "Read" @@ -87,7 +94,7 @@ def _get_health_check_endpoints(regional_routing_contexts) -> Set[str]: return preferred_endpoints def _get_applicable_regional_routing_contexts(regional_routing_contexts: list[RegionalRoutingContext], - location_name_by_endpoint: Mapping[str, str], + normalized_location_name_by_endpoint: Mapping[str, str], fall_back_regional_routing_context: RegionalRoutingContext, exclude_location_list: list[str], circuit_breaker_exclude_list: list[str], @@ -107,8 +114,8 @@ def _get_applicable_regional_routing_contexts(regional_routing_contexts: list[Re :param regional_routing_contexts: The initial list of regional contexts to filter. :type regional_routing_contexts: list[RegionalRoutingContext] - :param location_name_by_endpoint: A mapping from endpoint URL to location name. - :type location_name_by_endpoint: Mapping[str, str] + :param normalized_location_name_by_endpoint: A mapping from endpoint URL to normalized location name. + :type normalized_location_name_by_endpoint: Mapping[str, str] :param fall_back_regional_routing_context: The context to use as a fallback if all others are filtered out. :type fall_back_regional_routing_context: RegionalRoutingContext :param exclude_location_list: A list of location names to exclude, based on user configuration. @@ -120,11 +127,17 @@ def _get_applicable_regional_routing_contexts(regional_routing_contexts: list[Re :return: A filtered and reordered list of regional routing contexts. :rtype: list[RegionalRoutingContext] """ + normalized_excluded_locations = {_normalize_region_name(location) for location in exclude_location_list} + normalized_circuit_breaker_locations = { + _normalize_region_name(location) for location in circuit_breaker_exclude_list + } + # filter endpoints by excluded locations applicable_regional_routing_contexts = [] user_excluded_regional_routing_contexts = [] for regional_routing_context in regional_routing_contexts: - if location_name_by_endpoint.get(regional_routing_context.get_primary()) not in exclude_location_list: + normalized_location_name = normalized_location_name_by_endpoint.get(regional_routing_context.get_primary(), "") + if normalized_location_name not in normalized_excluded_locations: applicable_regional_routing_contexts.append(regional_routing_context) else: user_excluded_regional_routing_contexts.append(regional_routing_context) @@ -133,7 +146,8 @@ def _get_applicable_regional_routing_contexts(regional_routing_contexts: list[Re final_applicable_contexts = [] circuit_breaker_excluded_contexts = [] for regional_routing_context in applicable_regional_routing_contexts: - if location_name_by_endpoint.get(regional_routing_context.get_primary()) in circuit_breaker_exclude_list: + normalized_location_name = normalized_location_name_by_endpoint.get(regional_routing_context.get_primary(), "") + if normalized_location_name in normalized_circuit_breaker_locations: circuit_breaker_excluded_contexts.append(regional_routing_context) else: final_applicable_contexts.append(regional_routing_context) @@ -171,7 +185,14 @@ def __init__( self.account_locations_by_write_endpoints: dict[str, str] = {} # pylint: disable=name-too-long self.account_write_locations: list[str] = [] self.account_read_locations: list[str] = [] + self._read_locations_by_normalized: dict[str, RegionalRoutingContext] = {} + self._write_locations_by_normalized: dict[str, RegionalRoutingContext] = {} + self._normalized_location_by_read_endpoint: dict[str, str] = {} + self._normalized_location_by_write_endpoint: dict[str, str] = {} + self._normalized_name_by_read_location: dict[str, str] = {} + self._normalized_name_by_write_location: dict[str, str] = {} self.connection_policy: ConnectionPolicy = connection_policy + self._config_mismatch_warning_dedupe: set[tuple[str, tuple[str, ...], tuple[str, ...]]] = set() def get_write_regional_routing_contexts(self): return self.write_regional_routing_contexts @@ -228,6 +249,38 @@ def _get_configured_excluded_locations(self, request: RequestObject) -> list[str return excluded_locations + def _emit_config_mismatch_warning_once( + self, + configured_locations: list[str], + available_locations: list[str], + setting_name: str): + if not configured_locations: + return + + available_by_normalized = {_normalize_region_name(location): location for location in available_locations} + unmatched_locations = [ + location + for location in configured_locations + if _normalize_region_name(location) not in available_by_normalized + ] + + if unmatched_locations: + dedupe_key = ( + setting_name, + tuple(sorted(_normalize_region_name(location) for location in unmatched_locations)), + tuple(sorted(available_by_normalized.keys())), + ) + if dedupe_key in self._config_mismatch_warning_dedupe: + return + self._config_mismatch_warning_dedupe.add(dedupe_key) + + logger.warning( + "Ignoring %s entries that did not match account regions: %s. Available regions: %s", + setting_name, + unmatched_locations, + available_locations, + ) + def _get_applicable_read_regional_routing_contexts(self, request: RequestObject) -> list[RegionalRoutingContext]: # Get configured excluded locations excluded_locations = self._get_configured_excluded_locations(request) @@ -236,7 +289,7 @@ def _get_applicable_read_regional_routing_contexts(self, request: RequestObject) if excluded_locations or request.excluded_locations_circuit_breaker: return _get_applicable_regional_routing_contexts( self.get_read_regional_routing_contexts(), - self.account_locations_by_read_endpoints, + self._normalized_location_by_read_endpoint, self.get_write_regional_routing_contexts()[0], excluded_locations, request.excluded_locations_circuit_breaker or [], @@ -253,7 +306,7 @@ def _get_applicable_write_regional_routing_contexts(self, request: RequestObject if excluded_locations or request.excluded_locations_circuit_breaker: return _get_applicable_regional_routing_contexts( self.get_write_regional_routing_contexts(), - self.account_locations_by_write_endpoints, + self._normalized_location_by_write_endpoint, self.default_regional_routing_context, excluded_locations, request.excluded_locations_circuit_breaker or [], @@ -281,6 +334,8 @@ def _resolve_endpoint_without_preferred_locations(self, request, is_write, locat ordered_locations = self.account_write_locations if is_write else self.account_read_locations all_contexts_by_loc = (self.account_write_regional_routing_contexts_by_location if is_write else self.account_read_regional_routing_contexts_by_location) + normalized_name_by_location = (self._normalized_name_by_write_location if is_write + else self._normalized_name_by_read_location) # Safety check: if endpoint discovery is off or location cache isn't populated, fallback. if not self.connection_policy.EnableEndpointDiscovery or not ordered_locations: @@ -298,14 +353,20 @@ def _resolve_endpoint_without_preferred_locations(self, request, is_write, locat excluded_locations = self._get_configured_excluded_locations(request) circuit_breaker_excluded_locations = request.excluded_locations_circuit_breaker or [] + normalized_excluded_locations = {_normalize_region_name(location) for location in excluded_locations} + normalized_circuit_breaker_locations = { + _normalize_region_name(location) for location in circuit_breaker_excluded_locations + } + applicable_contexts = [] circuit_breaker_contexts = [] for loc_name in ordered_locations: if loc_name in all_contexts_by_loc: context = all_contexts_by_loc[loc_name] - if loc_name in excluded_locations: + normalized_location_name = normalized_name_by_location.get(loc_name, "") + if normalized_location_name in normalized_excluded_locations: continue # Skip user-excluded locations - if loc_name in circuit_breaker_excluded_locations: + if normalized_location_name in normalized_circuit_breaker_locations: circuit_breaker_contexts.append(context) else: applicable_contexts.append(context) @@ -365,6 +426,11 @@ def resolve_service_endpoint(self, request): def should_refresh_endpoints(self): # pylint: disable=too-many-return-statements most_preferred_location = self.effective_preferred_locations[0] if self.effective_preferred_locations else None + normalized_most_preferred_location = ( + _normalize_region_name(most_preferred_location) if most_preferred_location else None + ) + read_locations_by_normalized = self._read_locations_by_normalized + write_locations_by_normalized = self._write_locations_by_normalized # we should schedule refresh in background if we are unable to target the user's most preferredLocation. if self.connection_policy.EnableEndpointDiscovery: @@ -372,18 +438,13 @@ def should_refresh_endpoints(self): # pylint: disable=too-many-return-statement should_refresh = (self.connection_policy.UseMultipleWriteLocations and not self.enable_multiple_writable_locations) - if (most_preferred_location and most_preferred_location in - self.account_read_regional_routing_contexts_by_location): - if (self.account_read_regional_routing_contexts_by_location - and most_preferred_location in self.account_read_regional_routing_contexts_by_location): - most_preferred_read_endpoint = ( - self.account_read_regional_routing_contexts_by_location)[most_preferred_location] - if (most_preferred_read_endpoint and - most_preferred_read_endpoint != self.read_regional_routing_contexts[0]): - # For reads, we can always refresh in background as we can alternate to - # other available read endpoints - return True - else: + if (normalized_most_preferred_location and normalized_most_preferred_location in + read_locations_by_normalized): + most_preferred_read_endpoint = read_locations_by_normalized[normalized_most_preferred_location] + if (most_preferred_read_endpoint and + most_preferred_read_endpoint != self.read_regional_routing_contexts[0]): + # For reads, we can always refresh in background as we can alternate to + # other available read endpoints return True if not self.can_use_multiple_write_locations(): @@ -394,10 +455,11 @@ def should_refresh_endpoints(self): # pylint: disable=too-many-return-statement # we have an alternate write endpoint return True return should_refresh - if (most_preferred_location and - most_preferred_location in self.account_write_regional_routing_contexts_by_location): - most_preferred_write_regional_endpoint = ( - self.account_write_regional_routing_contexts_by_location)[most_preferred_location] + if (normalized_most_preferred_location and + normalized_most_preferred_location in write_locations_by_normalized): + most_preferred_write_regional_endpoint = write_locations_by_normalized[ + normalized_most_preferred_location + ] if most_preferred_write_regional_endpoint: should_refresh |= most_preferred_write_regional_endpoint != self.write_regional_routing_contexts[0] return should_refresh @@ -465,6 +527,32 @@ def update_location_cache(self, write_locations=None, read_locations=None, enabl self.account_locations_by_write_endpoints, self.account_write_locations) = get_regional_routing_contexts_by_loc(write_locations) + # Cache normalized lookups once per topology refresh to avoid repeating work per request. + self._read_locations_by_normalized = { + _normalize_region_name(name): context + for name, context in self.account_read_regional_routing_contexts_by_location.items() + } + self._write_locations_by_normalized = { + _normalize_region_name(name): context + for name, context in self.account_write_regional_routing_contexts_by_location.items() + } + self._normalized_location_by_read_endpoint = { + endpoint: _normalize_region_name(name) + for endpoint, name in self.account_locations_by_read_endpoints.items() + } + self._normalized_location_by_write_endpoint = { + endpoint: _normalize_region_name(name) + for endpoint, name in self.account_locations_by_write_endpoints.items() + } + self._normalized_name_by_read_location = { + name: _normalize_region_name(name) + for name in self.account_read_regional_routing_contexts_by_location + } + self._normalized_name_by_write_location = { + name: _normalize_region_name(name) + for name in self.account_write_regional_routing_contexts_by_location + } + # if preferred locations is empty and the default endpoint is a global endpoint, # we should use the read locations from gateway as effective preferred locations if self.connection_policy.PreferredLocations: @@ -478,17 +566,40 @@ def update_location_cache(self, write_locations=None, read_locations=None, enabl self.account_write_regional_routing_contexts_by_location, self.account_write_locations, EndpointOperationType.WriteType, - self.default_regional_routing_context + self.default_regional_routing_context, + self._write_locations_by_normalized, ) self.read_regional_routing_contexts = self.get_preferred_regional_routing_contexts( self.account_read_regional_routing_contexts_by_location, self.account_read_locations, EndpointOperationType.ReadType, - self.write_regional_routing_contexts[0] + self.write_regional_routing_contexts[0], + self._read_locations_by_normalized, ) + # Config-time visibility for misconfigured region names. Dedupe ensures periodic + # refreshes do not re-emit identical warnings; new mismatches still surface because + # the dedupe key includes the available account regions snapshot. + if self.connection_policy.PreferredLocations: + self._emit_config_mismatch_warning_once( + self.connection_policy.PreferredLocations, + self.account_read_locations or self.account_write_locations, + "preferred_locations", + ) + if self.connection_policy.ExcludedLocations: + self._emit_config_mismatch_warning_once( + list(self.connection_policy.ExcludedLocations), + self.account_read_locations or self.account_write_locations, + "excluded_locations", + ) + def get_preferred_regional_routing_contexts( - self, endpoints_by_location, orderedLocations, expected_available_operation, fallback_endpoint + self, + endpoints_by_location, + orderedLocations, + expected_available_operation, + fallback_endpoint, + endpoints_by_normalized_location=None, ): regional_endpoints = [] # if enableEndpointDiscovery is false, we always use the defaultEndpoint that @@ -500,12 +611,18 @@ def get_preferred_regional_routing_contexts( ): unavailable_endpoints = [] if self.effective_preferred_locations: + endpoints_by_normalized_location = endpoints_by_normalized_location or { + _normalize_region_name(location): endpoint + for location, endpoint in endpoints_by_location.items() + } + # When client can not use multiple write locations, preferred locations # list should only be used determining read endpoints order. If client # can use multiple write locations, preferred locations list should be # used for determining both read and write endpoints order. for location in self.effective_preferred_locations: - regional_endpoint = endpoints_by_location.get(location) + normalized_location = _normalize_region_name(location) + regional_endpoint = endpoints_by_normalized_location.get(normalized_location) if regional_endpoint: if self.is_endpoint_unavailable(regional_endpoint.get_primary(), expected_available_operation): @@ -579,8 +696,8 @@ def GetLocationalEndpoint(default_endpoint, location_name): global_database_account_name = hostname_parts[0] # Prepare the locational_database_account_name as contoso-eastus for location_name 'east us' - locational_database_account_name = global_database_account_name + "-" + location_name.replace(" ", "") - locational_database_account_name = locational_database_account_name.lower() + normalized_location_name = _normalize_region_name(location_name) + locational_database_account_name = global_database_account_name + "-" + normalized_location_name # Replace 'contoso' with 'contoso-eastus' and return locational_endpoint # as https://contoso-eastus.documents.azure.com:443/ diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_version.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_version.py index af83f50bae55..d581f6fc3808 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_version.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_version.py @@ -19,4 +19,4 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -VERSION = "4.14.6" +VERSION = "4.14.7" diff --git a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py index 090717519222..3c25f41fed3f 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py +++ b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py @@ -26,6 +26,11 @@ location4_endpoint = "https://location4.documents.azure.com" refresh_time_interval_in_ms = 1000 +canonical_location1_name = "East US 2" +canonical_location2_name = "West US 3" +canonical_location1_endpoint = "https://eastus2.documents.azure.com" +canonical_location2_endpoint = "https://westus3.documents.azure.com" + def create_database_account(enable_multiple_writable_locations): db_acc = DatabaseAccount() @@ -46,6 +51,20 @@ def refresh_location_cache(preferred_locations, use_multiple_write_locations, co connection_policy=connection_policy) return lc + +def create_database_account_with_canonical_regions(enable_multiple_writable_locations): + db_acc = DatabaseAccount() + db_acc._WritableLocations = [ + {"name": canonical_location1_name, "databaseAccountEndpoint": canonical_location1_endpoint}, + {"name": canonical_location2_name, "databaseAccountEndpoint": canonical_location2_endpoint}, + ] + db_acc._ReadableLocations = [ + {"name": canonical_location1_name, "databaseAccountEndpoint": canonical_location1_endpoint}, + {"name": canonical_location2_name, "databaseAccountEndpoint": canonical_location2_endpoint}, + ] + db_acc._EnableMultipleWritableLocations = enable_multiple_writable_locations + return db_acc + @pytest.mark.cosmosEmulator class TestLocationCache: @@ -390,6 +409,28 @@ def test_resolve_endpoint_respects_excluded_regions_when_use_preferred_locations # Assert the correct behavior. assert resolved_endpoint == location2_endpoint + def test_resolve_endpoint_without_preferred_locations_supports_normalized_exclusions(self): + # This specifically exercises _resolve_endpoint_without_preferred_locations by + # setting use_preferred_locations=False. + lc = refresh_location_cache( + preferred_locations=[], + use_multiple_write_locations=True, + ) + db_acc = create_database_account_with_canonical_regions(enable_multiple_writable_locations=True) + lc.perform_on_database_account_read(db_acc) + + write_request = RequestObject(ResourceType.Document, _OperationType.Create, None) + write_request.use_preferred_locations = False + write_request.excluded_locations = ["east-us-2"] + + assert lc.resolve_service_endpoint(write_request) == canonical_location2_endpoint + + read_request = RequestObject(ResourceType.Document, _OperationType.Read, None) + read_request.use_preferred_locations = False + read_request.excluded_locations = ["west_us_3"] + + assert lc.resolve_service_endpoint(read_request) == canonical_location1_endpoint + def test_regional_fallback_when_primary_is_excluded(self): # This test simulates a scenario where the primary preferred region is excluded # by the user, and the secondary is excluded by the circuit breaker. @@ -472,5 +513,88 @@ def test_write_fallback_to_global_after_regional_retries_exhausted(self): final_endpoint = lc.resolve_service_endpoint(write_request) assert final_endpoint == location1_endpoint + def test_preferred_locations_support_normalized_region_names(self): + # Preferred locations should match account region names even with case/spacing/separator variations. + lc = refresh_location_cache(["east-us-2", " west_us_3 "], True) + db_acc = create_database_account_with_canonical_regions(enable_multiple_writable_locations=True) + lc.perform_on_database_account_read(db_acc) + + write_contexts = lc.get_write_regional_routing_contexts() + read_contexts = lc.get_read_regional_routing_contexts() + + assert write_contexts[0].get_primary() == canonical_location1_endpoint + assert write_contexts[1].get_primary() == canonical_location2_endpoint + assert read_contexts[0].get_primary() == canonical_location1_endpoint + assert read_contexts[1].get_primary() == canonical_location2_endpoint + + def test_excluded_locations_support_normalized_region_names(self): + # Excluded locations should filter regions even when normalized names are used. + connection_policy = documents.ConnectionPolicy() + connection_policy.ExcludedLocations = ["east-us-2"] + + lc = refresh_location_cache([canonical_location1_name, canonical_location2_name], True, connection_policy) + db_acc = create_database_account_with_canonical_regions(enable_multiple_writable_locations=True) + lc.perform_on_database_account_read(db_acc) + + read_request = RequestObject(ResourceType.Document, _OperationType.Read, None) + write_request = RequestObject(ResourceType.Document, _OperationType.Create, None) + write_request.excluded_locations = ["west_us_3"] + + assert lc.resolve_service_endpoint(read_request) == canonical_location2_endpoint + assert lc.resolve_service_endpoint(write_request) == canonical_location1_endpoint + + def test_should_refresh_endpoints_handles_normalized_preferred_region(self): + # should_refresh_endpoints must match canonical region keys even when the + # customer's preferred location uses non-canonical spelling. + lc = refresh_location_cache(["east-us-2"], True) + db_acc = create_database_account_with_canonical_regions(enable_multiple_writable_locations=True) + lc.perform_on_database_account_read(db_acc) + + # Most-preferred is already the primary; no background refresh should be triggered. + assert lc.should_refresh_endpoints() is False + + def test_get_locational_endpoint_normalizes_customer_region_string(self): + # GetLocationalEndpoint is used during bootstrap fallback with the customer-supplied + # preferred region string. It must produce the canonical regional URL for any + # accepted normalization variant. + default_endpoint_url = "https://contoso.documents.azure.com:443/" + expected_endpoint = "https://contoso-eastus2.documents.azure.com:443/" + + for region_input in ("East US 2", "east us 2", "eastus2", "east-us-2", "east_us_2", " EastUs2 "): + assert LocationCache.GetLocationalEndpoint(default_endpoint_url, region_input) == expected_endpoint + + def test_unmatched_excluded_locations_warning_is_deduped(self, caplog): + connection_policy = documents.ConnectionPolicy() + connection_policy.ExcludedLocations = ["unknown-region"] + lc = refresh_location_cache([canonical_location1_name], True, connection_policy) + db_acc = create_database_account_with_canonical_regions(enable_multiple_writable_locations=True) + with caplog.at_level("WARNING", logger="azure.cosmos.LocationCache"): + lc.perform_on_database_account_read(db_acc) + request = RequestObject(ResourceType.Document, _OperationType.Read, None) + lc.resolve_service_endpoint(request) + lc.resolve_service_endpoint(request) + # Simulate a periodic refresh with unchanged topology and config. + lc.perform_on_database_account_read(db_acc) + + unmatched_logs = [ + record for record in caplog.records + if "Ignoring excluded_locations entries" in record.getMessage() + ] + assert len(unmatched_logs) == 1 + + def test_unmatched_preferred_locations_warning_is_deduped(self, caplog): + with caplog.at_level("WARNING", logger="azure.cosmos.LocationCache"): + lc = refresh_location_cache(["unknown-region"], True) + db_acc = create_database_account_with_canonical_regions(enable_multiple_writable_locations=True) + lc.perform_on_database_account_read(db_acc) + # Simulate a periodic refresh with unchanged topology and config. + lc.perform_on_database_account_read(db_acc) + + unmatched_logs = [ + record for record in caplog.records + if "Ignoring preferred_locations entries" in record.getMessage() + ] + assert len(unmatched_logs) == 1 + if __name__ == "__main__": - unittest.main() + unittest.main() \ No newline at end of file