Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos/pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ markers =
cosmosEmulator: marks tests as depending in Cosmos DB Emulator.
cosmosLong: marks tests to be run on a Cosmos DB live account.
cosmosQuery: marks tests running queries on Cosmos DB live account.
cosmosAAD: marks tests running data-plane operations with AAD auth on a Cosmos DB live account.
cosmosSplit: marks test where there are partition splits on CosmosDB live account.
cosmosMultiRegion: marks tests running on a Cosmos DB live account with multi-region and multi-write enabled.
cosmosCircuitBreaker: marks tests running on Cosmos DB live account with per partition circuit breaker enabled and multi-write enabled.
Expand Down
42 changes: 40 additions & 2 deletions sdk/cosmos/azure-cosmos/tests/test_aad.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from azure.cosmos import DatabaseProxy, ContainerProxy, exceptions
from azure.core.exceptions import HttpResponseError



def _remove_padding(encoded_string):
while encoded_string.endswith("="):
encoded_string = encoded_string[0:len(encoded_string) - 1]
Expand All @@ -35,7 +37,7 @@ def get_test_item(num):

class CosmosEmulatorCredential(object):
def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
# type: (*str, **object) -> AccessToken
"""Request an access token for the emulator. Based on Azure Core's Access Token Credential.

This method is called automatically by Azure SDK clients.
Expand Down Expand Up @@ -86,18 +88,50 @@ def get_token(self, *scopes, **kwargs):


@pytest.mark.cosmosEmulator
@pytest.mark.skipif(
not test_config.TestConfig.is_emulator
and test_config.TestConfig.data_auth_mode != 'aad',
reason="On a live account, run this file with COSMOS_TEST_DATA_AUTH_MODE=aad "
"so the dual-client factory returns the AAD branch. Otherwise the "
"test would silently use the master key and `delete_database` would "
"succeed instead of returning the asserted 403.",
)
class TestAAD(unittest.TestCase):
client: cosmos_client.CosmosClient = None
database: DatabaseProxy = None
container: ContainerProxy = None
configs = test_config.TestConfig
host = configs.host
masterKey = configs.masterKey
# Emulator-only: the hand-crafted JWT lets us exercise the AAD code path
# against the local emulator (which has no real AAD endpoint). On a live
# account this attribute is unused; setUpClass routes through the dual-client
# factory below instead.
credential = CosmosEmulatorCredential() if configs.is_emulator else configs.credential
_skip_scope_tests_on_non_emulator = pytest.mark.skipif(
not configs.is_emulator,
reason="Scope capture tests are emulator-specific (localhost audience)."
)

@classmethod
def setUpClass(cls):
cls.client = cosmos_client.CosmosClient(cls.host, cls.credential)
# Two construction paths:
#
# * Emulator runs (`pytest -m cosmosEmulator`): build the client
# directly with `CosmosEmulatorCredential` so the AAD JWT-parsing
# code path is exercised against the local emulator. The dual-client
# factory cannot do this — on the emulator it returns the master-key
# client and bypasses AAD entirely.
#
# * Live runs (`pytest -m cosmosAAD` on the AAD lane, or any live
# run with `COSMOS_TEST_DATA_AUTH_MODE=aad`): go through
# `TestConfig.create_data_client()` so this test exercises the
# same dual-client factory contract every other AAD-tagged test
# uses.
if cls.configs.is_emulator:
cls.client = cosmos_client.CosmosClient(cls.host, cls.credential)
else:
cls.client = test_config.TestConfig.create_data_client()
cls.database = cls.client.get_database_client(cls.configs.TEST_DATABASE_ID)
cls.container = cls.database.get_container_client(cls.configs.TEST_SINGLE_PARTITION_CONTAINER_ID)

Expand Down Expand Up @@ -133,6 +167,7 @@ def capturing_get_token(self, *scopes, **kwargs):
credential_cls.get_token = original_get_token
return scopes_captured, result

@_skip_scope_tests_on_non_emulator
def test_override_scope_no_fallback(self):
"""When override scope is provided, only that scope is used and no fallback occurs."""
override_scope = "https://my.custom.scope/.default"
Expand All @@ -156,6 +191,7 @@ def action(scopes_captured):
except Exception:
pass

@_skip_scope_tests_on_non_emulator
def test_override_scope_auth_error_no_fallback(self):
"""When override scope is provided and auth fails, no fallback to other scopes occurs."""
override_scope = "https://my.custom.scope/.default"
Expand All @@ -180,6 +216,7 @@ def action(scopes_captured):
finally:
del os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"]

@_skip_scope_tests_on_non_emulator
def test_account_scope_only(self):
"""When account scope is provided, only that scope is used."""
account_scope = "https://localhost/.default"
Expand All @@ -203,6 +240,7 @@ def action(scopes_captured):
except Exception:
pass

@_skip_scope_tests_on_non_emulator
def test_account_scope_fallback_on_error(self):
"""When account scope is provided and auth fails, fallback to default scope occurs."""
account_scope = "https://localhost/.default"
Expand Down
12 changes: 11 additions & 1 deletion sdk/cosmos/azure-cosmos/tests/test_aad_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from azure.cosmos.aio import CosmosClient, DatabaseProxy, ContainerProxy
from azure.core.exceptions import HttpResponseError



def _remove_padding(encoded_string):
while encoded_string.endswith("="):
encoded_string = encoded_string[0:len(encoded_string) - 1]
Expand All @@ -35,7 +37,7 @@ def get_test_item(num):

class CosmosEmulatorCredential(object):
async def get_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> AccessToken
# type: (*str, **object) -> AccessToken
"""Request an access token for the emulator. Based on Azure Core's Access Token Credential.

This method is called automatically by Azure SDK clients.
Expand Down Expand Up @@ -94,6 +96,10 @@ class TestAADAsync(unittest.IsolatedAsyncioTestCase):
host = configs.host
masterKey = configs.masterKey
credential = CosmosEmulatorCredential() if configs.is_emulator else configs.credential_async
_skip_scope_tests_on_non_emulator = pytest.mark.skipif(
not configs.is_emulator,
reason="Scope capture tests are emulator-specific (localhost audience)."
)

@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -146,6 +152,7 @@ async def capturing_get_token(self, *scopes, **kwargs):
finally:
credential_cls.get_token = orig_get_token

@_skip_scope_tests_on_non_emulator
async def test_override_scope_no_fallback_async(self):
"""When override scope is provided, only that scope is used and no fallback occurs."""
override_scope = "https://my.custom.scope/.default"
Expand All @@ -172,6 +179,7 @@ async def action(scopes_captured):
except Exception:
pass

@_skip_scope_tests_on_non_emulator
async def test_override_scope_no_fallback_on_error_async(self):
"""When override scope is provided and auth fails, no fallback occurs."""
override_scope = "https://my.custom.scope/.default"
Expand Down Expand Up @@ -205,6 +213,7 @@ async def action(scopes_captured):
except Exception:
pass

@_skip_scope_tests_on_non_emulator
async def test_account_scope_only_async(self):
"""When account scope is provided, only that scope is used."""
account_scope = "https://localhost/.default"
Expand All @@ -230,6 +239,7 @@ async def action(scopes_captured):
except Exception:
pass

@_skip_scope_tests_on_non_emulator
async def test_account_scope_fallback_on_error_async(self):
"""When account scope is provided and auth fails, fallback to default scope occurs."""
account_scope = "https://localhost/.default"
Expand Down
11 changes: 7 additions & 4 deletions sdk/cosmos/azure-cosmos/tests/test_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ class _config:


@pytest.mark.cosmosQuery
# @pytest.mark.cosmosAAD # TEMP: disabled to validate AAD pipeline using only test_aad.py
class TestAggregateQuery(unittest.TestCase):
client: cosmos_client.CosmosClient = None
key_client: cosmos_client.CosmosClient = None

@classmethod
def setUpClass(cls):
Expand All @@ -40,7 +42,7 @@ def setUpClass(cls):
@classmethod
def tearDownClass(cls) -> None:
try:
cls.created_db.delete_container(cls.created_collection.id)
cls.key_db.delete_container(cls.created_collection.id)
except CosmosHttpResponseError:
pass

Expand All @@ -52,9 +54,10 @@ def _setup(cls):
"'masterKey' and 'host' at the top of this class to run the "
"tests.")

cls.client = cosmos_client.CosmosClient(_config.host, _config.master_key)
cls.created_db = cls.client.get_database_client(test_config.TestConfig.TEST_DATABASE_ID)
cls.created_collection = cls._create_collection(cls.created_db)
cls.key_client, cls.key_db, cls.client, cls.created_db = (
test_config.TestConfig.create_test_clients(test_config.TestConfig.TEST_DATABASE_ID))
created_collection_ref = cls._create_collection(cls.key_db)
cls.created_collection = cls.created_db.get_container_client(created_collection_ref.id)

# test documents
document_definitions = []
Expand Down
18 changes: 10 additions & 8 deletions sdk/cosmos/azure-cosmos/tests/test_auto_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
@pytest.mark.cosmosLong
class TestAutoScale(unittest.TestCase):
client: CosmosClient = None
key_client: CosmosClient = None
host = test_config.TestConfig.host
masterKey = test_config.TestConfig.masterKey
connectionPolicy = test_config.TestConfig.connectionPolicy
Expand All @@ -27,8 +28,9 @@ def setUpClass(cls):
"'masterKey' and 'host' at the top of this class to run the "
"tests.")

cls.client = cosmos_client.CosmosClient(cls.host, cls.masterKey)
cls.created_database = cls.client.get_database_client(test_config.TestConfig.TEST_DATABASE_ID)
cls.key_client = cosmos_client.CosmosClient(cls.host, cls.masterKey)
cls.client = test_config.TestConfig.create_data_client()
cls.created_database = cls.key_client.get_database_client(test_config.TestConfig.TEST_DATABASE_ID)

def test_autoscale_create_container(self):
container_id = None
Expand Down Expand Up @@ -75,7 +77,7 @@ def test_autoscale_create_database(self):
database_id = "db_auto_scale_" + str(uuid.uuid4())
try:
# Testing auto_scale_settings for the create_database method
created_database = self.client.create_database(database_id, offer_throughput=ThroughputProperties(
created_database = self.key_client.create_database(database_id, offer_throughput=ThroughputProperties(
auto_scale_max_throughput=5000,
auto_scale_increment_percent=2))
created_db_properties = created_database.get_throughput()
Expand All @@ -84,11 +86,11 @@ def test_autoscale_create_database(self):
# Testing the input value of the increment_percentage
assert created_db_properties.auto_scale_increment_percent == 2

self.client.delete_database(created_database.id)
self.key_client.delete_database(created_database.id)

# Testing auto_scale_settings for the create_database_if_not_exists method
database_id = "db_auto_scale_2_" + str(uuid.uuid4())
created_database = self.client.create_database_if_not_exists(database_id,
created_database = self.key_client.create_database_if_not_exists(database_id,
offer_throughput=ThroughputProperties(
auto_scale_max_throughput=9000,
auto_scale_increment_percent=11))
Expand All @@ -98,13 +100,13 @@ def test_autoscale_create_database(self):
# Testing the input value of the increment_percentage
assert created_db_properties.auto_scale_increment_percent == 11
finally:
self.client.delete_database(database_id)
self.key_client.delete_database(database_id)

def test_autoscale_replace_throughput(self):
database_id = "replace_db" + str(uuid.uuid4())
container_id = None
try:
created_database = self.client.create_database(database_id, offer_throughput=ThroughputProperties(
created_database = self.key_client.create_database(database_id, offer_throughput=ThroughputProperties(
auto_scale_max_throughput=5000,
auto_scale_increment_percent=2))
created_database.replace_throughput(
Expand All @@ -114,7 +116,7 @@ def test_autoscale_replace_throughput(self):
assert created_db_properties.auto_scale_max_throughput == 7000
# Testing the input value of the increment_percentage
assert created_db_properties.auto_scale_increment_percent == 20
self.client.delete_database(database_id)
self.key_client.delete_database(database_id)

container_id = "container_with_auto_scale_settings" + str(uuid.uuid4())
created_container = self.created_database.create_container(
Expand Down
19 changes: 11 additions & 8 deletions sdk/cosmos/azure-cosmos/tests/test_auto_scale_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class TestAutoScaleAsync(unittest.IsolatedAsyncioTestCase):
masterKey = test_config.TestConfig.masterKey
connectionPolicy = test_config.TestConfig.connectionPolicy

key_client: CosmosClient = None
client: CosmosClient = None
created_database: DatabaseProxy = None

Expand All @@ -32,10 +33,12 @@ def setUpClass(cls):
"tests.")

async def asyncSetUp(self):
self.client = CosmosClient(self.host, self.masterKey)
self.created_database = self.client.get_database_client(self.TEST_DATABASE_ID)
self.key_client = CosmosClient(self.host, self.masterKey)
self.client = test_config.TestConfig.create_data_client_async()
self.created_database = self.key_client.get_database_client(self.TEST_DATABASE_ID)

async def asyncTearDown(self):
await self.key_client.close()
await self.client.close()

async def test_autoscale_create_container_async(self):
Expand Down Expand Up @@ -80,7 +83,7 @@ async def test_autoscale_create_database_async(self):
try:
# Testing auto_scale_settings for the create_database method
database_id = "db1_" + str(uuid.uuid4())
created_database = await self.client.create_database(database_id, offer_throughput=ThroughputProperties(
created_database = await self.key_client.create_database(database_id, offer_throughput=ThroughputProperties(
auto_scale_max_throughput=5000,
auto_scale_increment_percent=0))
created_db_properties = await created_database.get_throughput()
Expand All @@ -89,11 +92,11 @@ async def test_autoscale_create_database_async(self):
# Testing the input value of the increment_percentage
assert created_db_properties.auto_scale_increment_percent == 0

await self.client.delete_database(created_database.id)
await self.key_client.delete_database(created_database.id)

# Testing auto_scale_settings for the create_database_if_not_exists method
database_id = "db2_" + str(uuid.uuid4())
created_database = await self.client.create_database_if_not_exists(database_id, offer_throughput=ThroughputProperties(
created_database = await self.key_client.create_database_if_not_exists(database_id, offer_throughput=ThroughputProperties(
auto_scale_max_throughput=9000,
auto_scale_increment_percent=11))
created_db_properties = await created_database.get_throughput()
Expand All @@ -102,13 +105,13 @@ async def test_autoscale_create_database_async(self):
# Testing the input value of the increment_percentage
assert created_db_properties.auto_scale_increment_percent == 11
finally:
await self.client.delete_database(database_id)
await self.key_client.delete_database(database_id)

async def test_replace_throughput_async(self):
database_id = "replace_db" + str(uuid.uuid4())
container_id = None
try:
created_database = await self.client.create_database(database_id, offer_throughput=ThroughputProperties(
created_database = await self.key_client.create_database(database_id, offer_throughput=ThroughputProperties(
auto_scale_max_throughput=5000,
auto_scale_increment_percent=0))
await created_database.replace_throughput(
Expand All @@ -118,7 +121,7 @@ async def test_replace_throughput_async(self):
assert created_db_properties.auto_scale_max_throughput == 7000
# Testing the replaced value of the increment_percentage
assert created_db_properties.auto_scale_increment_percent == 20
await self.client.delete_database(database_id)
await self.key_client.delete_database(database_id)

container_id = "container_with_auto_scale_settings" + str(uuid.uuid4())
created_container = await self.created_database.create_container(
Expand Down
24 changes: 13 additions & 11 deletions sdk/cosmos/azure-cosmos/tests/test_availability_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ def _get_operation_type(test_operation_type: str) -> str:
raise ValueError("invalid operationType")

@pytest.mark.cosmosMultiRegion
# @pytest.mark.cosmosAAD # TEMP: disabled to validate AAD pipeline using only test_aad.py
class TestAvailabilityStrategy:
host = test_config.TestConfig.host
master_key = test_config.TestConfig.masterKey
Expand All @@ -278,7 +279,7 @@ def setup_class(cls):
logger.addHandler(cls.MOCK_HANDLER)
logger.setLevel(logging.DEBUG)

cls.client_without_fault = CosmosClient(cls.host, cls.master_key)
cls.client_without_fault = test_config.TestConfig.create_data_client()
database_account = cls.client_without_fault.get_database_account()
cls.write_locations = [loc["name"] for loc in database_account._WritableLocations]
cls.read_locations = [loc["name"] for loc in database_account._ReadableLocations]
Expand All @@ -292,8 +293,7 @@ def setup_method(self):

def _setup_method_with_custom_transport(self, custom_transport, default_endpoint=None, retry_write=False, **kwargs):
"""Initialize test client with optional custom transport and endpoint"""
if default_endpoint is None:
default_endpoint = self.host
endpoint = default_endpoint or self.host

# Set preferred locations with write locations first
preferred_locations = self.write_locations + [loc for loc in self.read_locations if loc not in self.write_locations]
Expand All @@ -302,14 +302,16 @@ def _setup_method_with_custom_transport(self, custom_transport, default_endpoint
if not container_id:
container_id = self.TEST_CONTAINER_MULTI_PARTITION_ID

client = CosmosClient(
default_endpoint,
self.master_key,
preferred_locations=preferred_locations,
transport=custom_transport,
retry_write=retry_write,
**kwargs
)
client_kwargs = {
"preferred_locations": preferred_locations,
"transport": custom_transport,
"retry_write": retry_write,
**kwargs,
}
if endpoint != self.host:
client = CosmosClient(endpoint, self.master_key, **client_kwargs)
else:
client = test_config.TestConfig.create_data_client(**client_kwargs)
db = client.get_database_client(self.TEST_DATABASE_ID)
container = db.get_container_client(container_id)
return {"client": client, "db": db, "col": container}
Expand Down
Loading
Loading