diff --git a/src/azure-cli/azure/cli/command_modules/cosmosdb/custom.py b/src/azure-cli/azure/cli/command_modules/cosmosdb/custom.py index 23c00ad47db..6a30862237d 100644 --- a/src/azure-cli/azure/cli/command_modules/cosmosdb/custom.py +++ b/src/azure-cli/azure/cli/command_modules/cosmosdb/custom.py @@ -273,6 +273,11 @@ def _create_database_account(client, locations = [] locations.append(Location(location_name=arm_location, failover_priority=0, is_zone_redundant=False)) + for loc in locations: + if loc.failover_priority == 0: + arm_location = loc.location_name + break + managed_service_identity = None SYSTEM_ID = '[system]' enable_system = False @@ -409,8 +414,22 @@ def _create_database_account(client, ) async_docdb_create = client.begin_create_or_update(resource_group_name, account_name, params) - docdb_account = async_docdb_create.result() - docdb_account = client.get(resource_group_name, account_name) # Workaround + try: + docdb_account = async_docdb_create.result() + except HttpResponseError as ex: + message = str(ex) + if (is_restore_request and + ex.status_code == 403 and + "does not exist" in message and + ("Database Account" in message or "Forbidden" in message)): + logger.warning( + "Encountered known service issue (403 'does not exist') while restoring Cosmos DB account '%s' " + "in resource group '%s'. Using client.get() as a workaround. Raw error: %s", + account_name, resource_group_name, ex + ) + docdb_account = client.get(resource_group_name, account_name) + else: + raise ex return docdb_account @@ -3518,6 +3537,24 @@ def cli_offline_region(client, resource_group_name, region): + # Function to normalize region name + def _normalize_region(region_name): + return region_name.replace(' ', '').lower() + + # Get the account to check for the region name + account = client.get(resource_group_name, account_name) + input_region_normalized = _normalize_region(region) + matched_region = None + + # Check matches in both read and write locations + for loc in account.locations: + if _normalize_region(loc.location_name) == input_region_normalized: + matched_region = loc.location_name + break + + if matched_region: + region = matched_region + region_parameter_for_offline = RegionForOnlineOffline(region=region) return client.begin_offline_region( resource_group_name=resource_group_name, diff --git a/src/azure-cli/azure/cli/command_modules/cosmosdb/tests/latest/test_cosmosdb_backuprestore_scenario.py b/src/azure-cli/azure/cli/command_modules/cosmosdb/tests/latest/test_cosmosdb_backuprestore_scenario.py index fb5b596c383..e262f148547 100644 --- a/src/azure-cli/azure/cli/command_modules/cosmosdb/tests/latest/test_cosmosdb_backuprestore_scenario.py +++ b/src/azure-cli/azure/cli/command_modules/cosmosdb/tests/latest/test_cosmosdb_backuprestore_scenario.py @@ -4,6 +4,7 @@ # -------------------------------------------------------------------------------------------- import os +import sys import unittest from unittest import mock @@ -518,4 +519,167 @@ def test_cosmosdb_xrr_single_region_account(self, resource_group): assert restored_account['restoreParameters']['restoreSource'] == restorable_database_account['id'] assert restored_account['restoreParameters']['restoreTimestampInUtc'] == restore_ts_string assert restored_account['restoreParameters']['sourceBackupLocation'] == source_loc_for_xrr - assert restored_account['writeLocations'][0]['locationName'] == 'North Central US' \ No newline at end of file + assert restored_account['writeLocations'][0]['locationName'] == 'North Central US' + + +class CosmosDBRestoreUnitTests(unittest.TestCase): + def setUp(self): + # Mock dependencies that might be missing or problematic to import + if 'azure.mgmt.cosmosdb.models' not in sys.modules: + sys.modules['azure.mgmt.cosmosdb.models'] = mock.MagicMock() + if 'azure.cli.core.util' not in sys.modules: + sys.modules['azure.cli.core.util'] = mock.MagicMock() + if 'knack.log' not in sys.modules: + sys.modules['knack.log'] = mock.MagicMock() + # Mocking knack.util.CLIError is crucial if it's used in custom.py + if 'knack.util' not in sys.modules: + mock_knack_util = mock.MagicMock() + mock_knack_util.CLIError = Exception + sys.modules['knack.util'] = mock_knack_util + + # Ensure Azure Core Exceptions are available + try: + import azure.core.exceptions + except ImportError: + mock_core_exceptions = mock.MagicMock() + # Define minimal exception class + class HttpResponseError(Exception): + def __init__(self, message=None, response=None, **kwargs): + self.message = message + self.response = response + self.status_code = kwargs.get('status_code', None) + def __str__(self): + return self.message or "" + mock_core_exceptions.HttpResponseError = HttpResponseError + mock_core_exceptions.ResourceNotFoundError = Exception + sys.modules['azure.core.exceptions'] = mock_core_exceptions + + def test_restore_handles_forbidden_error(self): + from azure.core.exceptions import HttpResponseError + # Lazy import to ensure mocks are applied first + from azure.cli.command_modules.cosmosdb.custom import _create_database_account + + # Setup mocks + client = mock.MagicMock() + + # Simulate the LRO poller raising the specific error + poller = mock.MagicMock() + error_json = '{"code":"Forbidden","message":"Database Account riks-models-003-acc-westeurope does not exist"}' + exception = HttpResponseError(message=error_json) + exception.status_code = 403 + + # side_effect raises the exception when called + poller.result.side_effect = exception + client.begin_create_or_update.return_value = poller + + # Simulate client.get returning the account successfully + mock_account = mock.MagicMock() + mock_account.provisioning_state = "Succeeded" + client.get.return_value = mock_account + + # Parameters + resource_group_name = "rg" + account_name = "myaccount" + + # Call the private function directly to verify logic + result = _create_database_account( + client=client, + resource_group_name=resource_group_name, + account_name=account_name, + locations=[], + is_restore_request=True, + arm_location="westeurope", + restore_source="/subscriptions/sub/providers/Microsoft.DocumentDB/locations/westeurope/restorableDatabaseAccounts/source-id", + restore_timestamp="2026-01-01T00:00:00+00:00" + ) + + # Assertions + # 1. begin_create_or_update called + client.begin_create_or_update.assert_called() + # 2. poller.result() called (and raised exception) + poller.result.assert_called() + # 3. client.get called (recovery mechanism) + client.get.assert_called_with(resource_group_name, account_name) + # 4. Result is the account returned by get + self.assertEqual(result, mock_account) + + def test_restore_raises_other_errors(self): + from azure.core.exceptions import HttpResponseError + from azure.cli.command_modules.cosmosdb.custom import _create_database_account + + # Setup mocks + client = mock.MagicMock() + poller = mock.MagicMock() + + # Different error + exception = HttpResponseError(message="Some other error") + exception.status_code = 500 + poller.result.side_effect = exception + client.begin_create_or_update.return_value = poller + + with self.assertRaises(HttpResponseError): + _create_database_account( + client=client, + resource_group_name="rg", + account_name="myaccount", + is_restore_request=True, + arm_location="westeurope", + restore_source="src", + restore_timestamp="ts" + ) + + def test_normal_create_does_not_suppress_error(self): + from azure.core.exceptions import HttpResponseError + from azure.cli.command_modules.cosmosdb.custom import _create_database_account + + # Setup mocks + client = mock.MagicMock() + poller = mock.MagicMock() + + # Same error but NOT a restore request + error_json = '{"code":"Forbidden","message":"Database Account riks-models-003-acc-westeurope does not exist"}' + exception = HttpResponseError(message=error_json) + exception.status_code = 403 + poller.result.side_effect = exception + client.begin_create_or_update.return_value = poller + + with self.assertRaises(HttpResponseError): + _create_database_account( + client=client, + resource_group_name="rg", + account_name="myaccount", + is_restore_request=False, # Normal create + arm_location="westeurope" + ) + + def test_normal_create_success(self): + from azure.cli.command_modules.cosmosdb.custom import _create_database_account + + # Setup mocks + client = mock.MagicMock() + poller = mock.MagicMock() + + # Simulate successful creation + mock_created_account = mock.MagicMock() + mock_created_account.provisioning_state = "Succeeded" + poller.result.return_value = mock_created_account + client.begin_create_or_update.return_value = poller + + # Call the private function + result = _create_database_account( + client=client, + resource_group_name="rg", + account_name="myaccount", + is_restore_request=False, + arm_location="westeurope" + ) + + # Assertions + # 1. begin_create_or_update called + client.begin_create_or_update.assert_called() + # 2. poller.result() called + poller.result.assert_called() + # 3. client.get should NOT be called since result() succeeded + client.get.assert_not_called() + # 4. Result matches + self.assertEqual(result, mock_created_account) \ No newline at end of file diff --git a/src/azure-cli/azure/cli/command_modules/cosmosdb/tests/latest/test_cosmosdb_commands.py b/src/azure-cli/azure/cli/command_modules/cosmosdb/tests/latest/test_cosmosdb_commands.py index 7c0faf0e12c..58982b8241d 100644 --- a/src/azure-cli/azure/cli/command_modules/cosmosdb/tests/latest/test_cosmosdb_commands.py +++ b/src/azure-cli/azure/cli/command_modules/cosmosdb/tests/latest/test_cosmosdb_commands.py @@ -239,7 +239,13 @@ def test_locations_database_accounts(self, resource_group): assert account1['readLocations'][0]['failoverPriority'] == 1 or account1['readLocations'][1]['failoverPriority'] == 1 self.cmd('az cosmosdb failover-priority-change -n {acc} -g {rg} --failover-policies {read_location}=0 {write_location}=1') - account2 = self.cmd('az cosmosdb show -n {acc} -g {rg}').get_output_in_json() + import time + for _ in range(0, 10): + account2 = self.cmd('az cosmosdb show -n {acc} -g {rg}').get_output_in_json() + if account2['writeLocations'][0]['locationName'] == "West US": + break + time.sleep(5) + assert len(account2['writeLocations']) == 1 assert len(account2['readLocations']) == 2 @@ -260,7 +266,7 @@ def test_locations_database_accounts_offline(self, resource_group): 'read_location': read_location }) - account_pre_offline = self.cmd('az cosmosdb create -n {acc} -g {rg} --locations regionName={write_location} failoverPriority=0 --locations regionName={read_location} failoverPriority=1').get_output_in_json() + account_pre_offline = self.cmd('az cosmosdb create -n {acc} -g {rg} --enable-automatic-failover --locations regionName={write_location} failoverPriority=0 --locations regionName={read_location} failoverPriority=1').get_output_in_json() assert account_pre_offline['writeLocations'][0]['locationName'] == "East US"