diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index b8b1c451ad10..a5d7a6fcaef9 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -7,6 +7,7 @@ #### Breaking Changes #### Bugs Fixed +* Fixed bug where the SDK could not connect to the local Cosmos DB emulator running in Docker with a remapped host port. The emulator advertises its internal host/port (e.g. `127.0.0.1:8081`) in its account topology, which is unreachable when the host port differs from `8081`. When the user-supplied endpoint targets `localhost` or `127.0.0.1`, the SDK now reuses that host/port for all regional endpoints returned by the gateway. See [PR 46896](https://github.com/Azure/azure-sdk-for-python/pull/46896) #### Other Changes diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index 1efd6d4841df..58648b5dec16 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -127,11 +127,83 @@ def __eq__(self, other): def __str__(self): return "Primary: " + self.primary_endpoint +def _is_local_emulator_endpoint(endpoint: Optional[str]) -> bool: + """Return True if the endpoint refers to the local Cosmos DB emulator. + + Hosts ``localhost`` and ``127.0.0.1`` are treated as emulator endpoints. + + :param endpoint: The endpoint URL to inspect, or ``None``. + :type endpoint: str or None + :returns: ``True`` if the endpoint's hostname is ``localhost`` or + ``127.0.0.1``; ``False`` otherwise (including when ``endpoint`` is + ``None`` / empty or cannot be parsed). + :rtype: bool + """ + if not endpoint: + return False + try: + hostname = urlparse(endpoint).hostname + except ValueError: + return False + return hostname in ("localhost", "127.0.0.1") + + +def _rewrite_endpoint_with_default(default_endpoint: str, regional_endpoint: str) -> str: + """Rewrite ``regional_endpoint``'s scheme/host/port to match ``default_endpoint``. + + The Cosmos DB emulator advertises its internal host/port (for example + ``127.0.0.1:8081``) in the database account topology. When the emulator + is running in a container with a remapped port, that advertised endpoint + is unreachable from the host. Rewriting it to the user-supplied endpoint + preserves connectivity while keeping the rest of the URI (path, etc.) intact. + + When both endpoints already advertise the same explicit port, the rewrite + is skipped so legitimate hostname differences are preserved (for example, + test setups that simulate multiple regions by advertising different + hostnames against the same emulator instance). + + :param str default_endpoint: The user-supplied account endpoint whose + scheme / host / port should be copied onto ``regional_endpoint``. + :param str regional_endpoint: The endpoint advertised by the gateway for + a specific region (typically the emulator's internal host:port). + :returns: A URL string with ``regional_endpoint``'s path/query preserved + and its scheme/netloc replaced with the values from + ``default_endpoint``. The input ``regional_endpoint`` is returned + unchanged if ``default_endpoint`` cannot be parsed, has no netloc, + or already shares the same explicit port. + :rtype: str + """ + try: + default_parsed = urlparse(default_endpoint) + regional_parsed = urlparse(regional_endpoint) + default_port = default_parsed.port + regional_port = regional_parsed.port + except ValueError: + return regional_endpoint + if not default_parsed.netloc: + return regional_endpoint + if default_port is not None and default_port == regional_port: + # Ports already match — rewriting would only collapse legitimate + # hostname distinctions (e.g. fault-injection tests that simulate + # multiple regions by advertising different hostnames against the + # same emulator instance). Leave the advertised endpoint untouched. + return regional_endpoint + return regional_parsed._replace( + scheme=default_parsed.scheme or regional_parsed.scheme, + netloc=default_parsed.netloc, + ).geturl() + + +def get_regional_routing_contexts_by_loc( + new_locations: list[dict[str, str]], + default_endpoint: Optional[str] = None, +): def get_regional_routing_contexts_by_loc(new_locations: list[dict[str, str]]): # construct from previous object regional_routing_contexts_by_location: OrderedDict[str, RegionalRoutingContext] = collections.OrderedDict() parsed_locations = [] + rewrite_to_default = _is_local_emulator_endpoint(default_endpoint) for new_location in new_locations: # if name in new_location and same for database account endpoint @@ -141,6 +213,12 @@ def get_regional_routing_contexts_by_loc(new_locations: list[dict[str, str]]): continue try: region_uri = new_location["databaseAccountEndpoint"] + if rewrite_to_default and default_endpoint is not None: + # When targeting the local emulator the server can advertise an + # internal host/port (e.g. 127.0.0.1:8081) that is unreachable + # from the caller (common with Docker port remapping). Reuse + # the user-supplied endpoint host/port so connections succeed. + region_uri = _rewrite_endpoint_with_default(default_endpoint, region_uri) parsed_locations.append(new_location["name"]) regional_object = RegionalRoutingContext(region_uri) regional_routing_contexts_by_location.update({new_location["name"]: regional_object}) @@ -596,15 +674,18 @@ def update_location_cache(self, write_locations=None, read_locations=None, enabl self.enable_multiple_writable_locations = enable_multiple_writable_locations if self.connection_policy.EnableEndpointDiscovery: + default_endpoint = self.default_regional_routing_context.get_primary() if read_locations: (self.account_read_regional_routing_contexts_by_location, self.account_locations_by_read_endpoints, - self.account_read_locations) = get_regional_routing_contexts_by_loc(read_locations) + self.account_read_locations) = get_regional_routing_contexts_by_loc( + read_locations, default_endpoint) if write_locations: (self.account_write_regional_routing_contexts_by_location, self.account_locations_by_write_endpoints, - self.account_write_locations) = get_regional_routing_contexts_by_loc(write_locations) + self.account_write_locations) = get_regional_routing_contexts_by_loc( + write_locations, default_endpoint) # Cache normalized lookups once per topology refresh to avoid repeating work per request. self._read_locations_by_normalized = { diff --git a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py index af0861929cbb..993b0a1c5fba 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py +++ b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py @@ -691,6 +691,119 @@ def test_location_cache_derived_state_consistency(self): assert read_after_second == [ctx.get_primary() for ctx in expected_read] assert write_after_second == [ctx.get_primary() for ctx in expected_write] + +class TestEmulatorEndpointRewrite: + """Tests that emulator setups (localhost / 127.0.0.1) ignore the host:port + advertised by the gateway and reuse the user-supplied endpoint instead. + + This addresses the issue where the Cosmos emulator running in Docker with + a remapped port (e.g. host port 8888 -> container port 8081) advertises its + internal port back to the client, making the returned regional endpoints + unreachable from the host. + """ + + @staticmethod + def _make_db_account(advertised_endpoint): + db_acc = DatabaseAccount() + db_acc._WritableLocations = [ + {"name": "South Central US", "databaseAccountEndpoint": advertised_endpoint} + ] + db_acc._ReadableLocations = [ + {"name": "South Central US", "databaseAccountEndpoint": advertised_endpoint} + ] + db_acc._EnableMultipleWritableLocations = False + return db_acc + + @pytest.mark.parametrize("user_endpoint", [ + "http://localhost:8888/", + "http://127.0.0.1:9000/", + ]) + def test_emulator_endpoint_is_preserved(self, user_endpoint): + connection_policy = documents.ConnectionPolicy() + lc = LocationCache(default_endpoint=user_endpoint, connection_policy=connection_policy) + db_acc = self._make_db_account("https://127.0.0.1:8081/") + + lc.perform_on_database_account_read(db_acc) + + write_contexts = lc.get_write_regional_routing_contexts() + read_contexts = lc.get_read_regional_routing_contexts() + assert len(write_contexts) == 1 + assert len(read_contexts) == 1 + # The advertised 127.0.0.1:8081 host:port should be replaced with the + # user-supplied host:port so the SDK can reach the emulator. + assert write_contexts[0].get_primary() == user_endpoint + assert read_contexts[0].get_primary() == user_endpoint + + def test_emulator_matching_port_preserves_advertised_host(self): + # When the user-supplied endpoint and the advertised endpoint already + # use the same port, the rewrite is intentionally skipped so the + # advertised hostname is preserved. This matters for test + # infrastructure (e.g. FaultInjectionTransport) that simulates + # multiple regions by advertising different hostnames (localhost vs + # 127.0.0.1) against the same physical emulator instance. + user_endpoint = "https://localhost:8081/" + advertised_endpoint = "https://127.0.0.1:8081/" + connection_policy = documents.ConnectionPolicy() + lc = LocationCache(default_endpoint=user_endpoint, connection_policy=connection_policy) + db_acc = self._make_db_account(advertised_endpoint) + + lc.perform_on_database_account_read(db_acc) + + write_contexts = lc.get_write_regional_routing_contexts() + read_contexts = lc.get_read_regional_routing_contexts() + assert write_contexts[0].get_primary() == advertised_endpoint + assert read_contexts[0].get_primary() == advertised_endpoint + + def test_non_emulator_endpoints_are_not_rewritten(self): + user_endpoint = "https://contoso.documents.azure.com:443/" + advertised_endpoint = "https://contoso-southcentralus.documents.azure.com:443/" + connection_policy = documents.ConnectionPolicy() + lc = LocationCache(default_endpoint=user_endpoint, connection_policy=connection_policy) + db_acc = self._make_db_account(advertised_endpoint) + + lc.perform_on_database_account_read(db_acc) + + write_contexts = lc.get_write_regional_routing_contexts() + read_contexts = lc.get_read_regional_routing_contexts() + assert write_contexts[0].get_primary() == advertised_endpoint + assert read_contexts[0].get_primary() == advertised_endpoint + + def test_emulator_endpoint_with_advertised_localhost_is_rewritten(self): + # Even when the advertised endpoint is also a localhost address (just + # with a different port like the in-container 8081), it should still + # be rewritten to the user-supplied host:port. + user_endpoint = "http://localhost:8888/" + advertised_endpoint = "http://localhost:8081/" + connection_policy = documents.ConnectionPolicy() + lc = LocationCache(default_endpoint=user_endpoint, connection_policy=connection_policy) + db_acc = self._make_db_account(advertised_endpoint) + + lc.perform_on_database_account_read(db_acc) + + write_contexts = lc.get_write_regional_routing_contexts() + assert write_contexts[0].get_primary() == user_endpoint + + def test_endpoint_discovery_disabled_skips_rewrite(self): + # When endpoint discovery is disabled, update_location_cache short-circuits + # before populating the per-region routing contexts at all, so the rewrite + # path is never reached and the SDK falls back to the user-supplied + # default endpoint for every request. + user_endpoint = "http://localhost:8888/" + advertised_endpoint = "https://127.0.0.1:8081/" + connection_policy = documents.ConnectionPolicy() + connection_policy.EnableEndpointDiscovery = False + lc = LocationCache(default_endpoint=user_endpoint, connection_policy=connection_policy) + db_acc = self._make_db_account(advertised_endpoint) + + lc.perform_on_database_account_read(db_acc) + + # No per-region contexts are populated when endpoint discovery is off. + assert lc.account_write_regional_routing_contexts_by_location == {} + assert lc.account_read_regional_routing_contexts_by_location == {} + # Routing falls back to the user-supplied default endpoint, not the + # gateway-advertised 127.0.0.1:8081. + assert lc.get_write_regional_routing_context() == user_endpoint + assert lc.get_read_regional_routing_context() == user_endpoint def test_resolve_endpoint_without_preferred_locations_supports_normalized_exclusions(self): # This specifically exercises _resolve_endpoint_without_preferred_locations by # setting use_preferred_locations=False.