Skip to content

Commit 7e47f0a

Browse files
dibahlfiscbedd
andauthored
AAD test coverage (#46568)
* AAD test coverage initial commit * commenting all AAD coverage except one as a smoke test * adding pipeline changes * adding documentation * update the naming convention such that our automation will collect it automagically * fixing CI pipeline binding * expanding coverage of AAD tests * expanding coverage of AAD tests * cleaning up resources * cleaning up resources * cleaning up resources * fixing pipeline errors * debugging pipleine * debugging pipleine * running a subset * enabling AAD at a bigger scale * AAD fixing tests * AAD fixing tests * fixing AAD tests * AAD fixing tests * AAD fixing tests * fixing AAD tests * addressing PR comments * fixing bugs * fixing Copilot comments * fixing Copilot comments * refactoring tests --------- Co-authored-by: Scott Beddall <scbedd@microsoft.com>
1 parent e9441f4 commit 7e47f0a

104 files changed

Lines changed: 4322 additions & 2322 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

sdk/cosmos/azure-cosmos/pytest.ini

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,13 @@ markers =
33
cosmosEmulator: marks tests as depending in Cosmos DB Emulator.
44
cosmosLong: marks tests to be run on a Cosmos DB live account.
55
cosmosQuery: marks tests running queries on Cosmos DB live account.
6+
cosmosAADLong: marks AAD tests for the standard live-account lane.
7+
cosmosAADSplit: marks AAD tests for partition split scenarios.
8+
cosmosAADMultiRegion: marks AAD tests for multi-region scenarios.
9+
cosmosAADCircuitBreaker: marks AAD tests for circuit-breaker scenarios.
10+
cosmosAADCircuitBreakerMultiRegion: marks AAD tests for single-master multi-region-read circuit-breaker scenarios.
11+
cosmosAADQuery: marks AAD tests for query-focused scenarios.
12+
cosmosAADPerPartitionAutomaticFailover: marks AAD tests for per-partition automatic failover scenarios.
613
cosmosSplit: marks test where there are partition splits on CosmosDB live account.
714
cosmosMultiRegion: marks tests running on a Cosmos DB live account with multi-region and multi-write enabled.
815
cosmosCircuitBreaker: marks tests running on Cosmos DB live account with per partition circuit breaker enabled and multi-write enabled.

sdk/cosmos/azure-cosmos/tests/test_aad.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313

1414
import azure.cosmos.cosmos_client as cosmos_client
1515
import test_config
16-
from azure.cosmos import DatabaseProxy, ContainerProxy, exceptions
16+
from azure.cosmos import DatabaseProxy, ContainerProxy
1717
from azure.core.exceptions import HttpResponseError
1818

19+
20+
1921
def _remove_padding(encoded_string):
2022
while encoded_string.endswith("="):
2123
encoded_string = encoded_string[0:len(encoded_string) - 1]
@@ -35,7 +37,7 @@ def get_test_item(num):
3537

3638
class CosmosEmulatorCredential(object):
3739
def get_token(self, *scopes, **kwargs):
38-
# type: (*str, **Any) -> AccessToken
40+
# type: (*str, **object) -> AccessToken
3941
"""Request an access token for the emulator. Based on Azure Core's Access Token Credential.
4042
4143
This method is called automatically by Azure SDK clients.
@@ -93,14 +95,21 @@ class TestAAD(unittest.TestCase):
9395
configs = test_config.TestConfig
9496
host = configs.host
9597
masterKey = configs.masterKey
96-
credential = CosmosEmulatorCredential() if configs.is_emulator else configs.credential
98+
# Emulator-only credential used by this class.
99+
credential = CosmosEmulatorCredential()
100+
_skip_on_non_emulator = pytest.mark.skipif(
101+
not configs.is_emulator,
102+
reason="Emulator credential tests are emulator-specific (localhost audience)."
103+
)
97104

98105
@classmethod
99106
def setUpClass(cls):
107+
# Emulator-only path: always use the emulator credential.
100108
cls.client = cosmos_client.CosmosClient(cls.host, cls.credential)
101109
cls.database = cls.client.get_database_client(cls.configs.TEST_DATABASE_ID)
102110
cls.container = cls.database.get_container_client(cls.configs.TEST_SINGLE_PARTITION_CONTAINER_ID)
103111

112+
@_skip_on_non_emulator
104113
def test_aad_credentials(self):
105114
print("Container info: " + str(self.container.read()))
106115
self.container.create_item(get_test_item(0))
@@ -110,14 +119,6 @@ def test_aad_credentials(self):
110119
print("Query result: " + str(query_results[0]))
111120
self.container.delete_item(item='Item_0', partition_key='pk')
112121

113-
# Attempting to do management operations will return a 403 Forbidden exception
114-
try:
115-
self.client.delete_database(self.configs.TEST_DATABASE_ID)
116-
except exceptions.CosmosHttpResponseError as e:
117-
assert e.status_code == 403
118-
print("403 error assertion success")
119-
120-
121122
def _run_with_scope_capture(self, credential_cls, action, *args, **kwargs):
122123
scopes_captured = []
123124
original_get_token = credential_cls.get_token
@@ -133,6 +134,7 @@ def capturing_get_token(self, *scopes, **kwargs):
133134
credential_cls.get_token = original_get_token
134135
return scopes_captured, result
135136

137+
@_skip_on_non_emulator
136138
def test_override_scope_no_fallback(self):
137139
"""When override scope is provided, only that scope is used and no fallback occurs."""
138140
override_scope = "https://my.custom.scope/.default"
@@ -156,6 +158,7 @@ def action(scopes_captured):
156158
except Exception:
157159
pass
158160

161+
@_skip_on_non_emulator
159162
def test_override_scope_auth_error_no_fallback(self):
160163
"""When override scope is provided and auth fails, no fallback to other scopes occurs."""
161164
override_scope = "https://my.custom.scope/.default"
@@ -180,6 +183,7 @@ def action(scopes_captured):
180183
finally:
181184
del os.environ["AZURE_COSMOS_AAD_SCOPE_OVERRIDE"]
182185

186+
@_skip_on_non_emulator
183187
def test_account_scope_only(self):
184188
"""When account scope is provided, only that scope is used."""
185189
account_scope = "https://localhost/.default"
@@ -203,6 +207,7 @@ def action(scopes_captured):
203207
except Exception:
204208
pass
205209

210+
@_skip_on_non_emulator
206211
def test_account_scope_fallback_on_error(self):
207212
"""When account scope is provided and auth fails, fallback to default scope occurs."""
208213
account_scope = "https://localhost/.default"

sdk/cosmos/azure-cosmos/tests/test_aad_async.py

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
from azure.core.credentials import AccessToken
1313

1414
import test_config
15-
from azure.cosmos import exceptions
1615
from azure.cosmos.aio import CosmosClient, DatabaseProxy, ContainerProxy
1716
from azure.core.exceptions import HttpResponseError
1817

18+
19+
1920
def _remove_padding(encoded_string):
2021
while encoded_string.endswith("="):
2122
encoded_string = encoded_string[0:len(encoded_string) - 1]
@@ -35,7 +36,7 @@ def get_test_item(num):
3536

3637
class CosmosEmulatorCredential(object):
3738
async def get_token(self, *scopes, **kwargs):
38-
# type: (*str, **Any) -> AccessToken
39+
# type: (*str, **object) -> AccessToken
3940
"""Request an access token for the emulator. Based on Azure Core's Access Token Credential.
4041
4142
This method is called automatically by Azure SDK clients.
@@ -93,16 +94,11 @@ class TestAADAsync(unittest.IsolatedAsyncioTestCase):
9394
configs = test_config.TestConfig
9495
host = configs.host
9596
masterKey = configs.masterKey
96-
credential = CosmosEmulatorCredential() if configs.is_emulator else configs.credential_async
97-
98-
@classmethod
99-
def setUpClass(cls):
100-
if (cls.credential == '[YOUR_KEY_HERE]' or
101-
cls.host == '[YOUR_ENDPOINT_HERE]'):
102-
raise Exception(
103-
"You must specify your Azure Cosmos account values for "
104-
"'masterKey' and 'host' at the top of this class to run the "
105-
"tests.")
97+
credential = CosmosEmulatorCredential()
98+
_skip_scope_tests_on_non_emulator = pytest.mark.skipif(
99+
not configs.is_emulator,
100+
reason="Scope capture tests are emulator-specific (localhost audience)."
101+
)
106102

107103
async def asyncSetUp(self):
108104
self.client = CosmosClient(self.host, self.credential)
@@ -112,9 +108,8 @@ async def asyncSetUp(self):
112108
async def asyncTearDown(self):
113109
await self.client.close()
114110

111+
@_skip_scope_tests_on_non_emulator
115112
async def test_aad_credentials_async(self):
116-
# Do any R/W data operations with your authorized AAD client
117-
118113
print("Container info: " + str(await self.container.read()))
119114
await self.container.create_item(get_test_item(0))
120115
print("Point read result: " + str(await self.container.read_item(item='Item_0', partition_key='pk')))
@@ -123,12 +118,6 @@ async def test_aad_credentials_async(self):
123118
print("Query result: " + str(query_results[0]))
124119
await self.container.delete_item(item='Item_0', partition_key='pk')
125120

126-
# Attempting to do management operations will return a 403 Forbidden exception
127-
try:
128-
await self.client.delete_database(self.configs.TEST_DATABASE_ID)
129-
except exceptions.CosmosHttpResponseError as e:
130-
assert e.status_code == 403
131-
print("403 error assertion success")
132121

133122
async def _run_with_scope_capture_async(self, credential_cls, action):
134123
scopes_captured = []
@@ -146,6 +135,7 @@ async def capturing_get_token(self, *scopes, **kwargs):
146135
finally:
147136
credential_cls.get_token = orig_get_token
148137

138+
@_skip_scope_tests_on_non_emulator
149139
async def test_override_scope_no_fallback_async(self):
150140
"""When override scope is provided, only that scope is used and no fallback occurs."""
151141
override_scope = "https://my.custom.scope/.default"
@@ -172,6 +162,7 @@ async def action(scopes_captured):
172162
except Exception:
173163
pass
174164

165+
@_skip_scope_tests_on_non_emulator
175166
async def test_override_scope_no_fallback_on_error_async(self):
176167
"""When override scope is provided and auth fails, no fallback occurs."""
177168
override_scope = "https://my.custom.scope/.default"
@@ -205,6 +196,7 @@ async def action(scopes_captured):
205196
except Exception:
206197
pass
207198

199+
@_skip_scope_tests_on_non_emulator
208200
async def test_account_scope_only_async(self):
209201
"""When account scope is provided, only that scope is used."""
210202
account_scope = "https://localhost/.default"
@@ -230,6 +222,7 @@ async def action(scopes_captured):
230222
except Exception:
231223
pass
232224

225+
@_skip_scope_tests_on_non_emulator
233226
async def test_account_scope_fallback_on_error_async(self):
234227
"""When account scope is provided and auth fails, fallback to default scope occurs."""
235228
account_scope = "https://localhost/.default"

sdk/cosmos/azure-cosmos/tests/test_aggregate.py

Lines changed: 66 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,24 @@
1515

1616

1717
class _config:
18+
is_aad_mode = test_config.TestConfig.data_auth_mode == "aad"
1819
host = test_config.TestConfig.host
1920
master_key = test_config.TestConfig.masterKey
2021
connection_policy = test_config.TestConfig.connectionPolicy
2122
PARTITION_KEY = 'key'
2223
UNIQUE_PARTITION_KEY = 'uniquePartitionKey'
2324
FIELD = 'field'
24-
DOCUMENTS_COUNT = 400
25-
DOCS_WITH_SAME_PARTITION_KEY = 200
25+
# Keep key-auth query coverage unchanged; trim only AAD runs to stay under CI timeout.
26+
DOCUMENTS_COUNT = 120 if is_aad_mode else 400
27+
DOCS_WITH_SAME_PARTITION_KEY = 60 if is_aad_mode else 200
2628
docs_with_numeric_id = 0
2729
sum = 0
2830

2931

3032
@pytest.mark.cosmosQuery
3133
class TestAggregateQuery(unittest.TestCase):
3234
client: cosmos_client.CosmosClient = None
35+
key_client: cosmos_client.CosmosClient = None
3336

3437
@classmethod
3538
def setUpClass(cls):
@@ -40,7 +43,7 @@ def setUpClass(cls):
4043
@classmethod
4144
def tearDownClass(cls) -> None:
4245
try:
43-
cls.created_db.delete_container(cls.created_collection.id)
46+
cls.key_db.delete_container(cls.created_collection.id)
4447
except CosmosHttpResponseError:
4548
pass
4649

@@ -52,9 +55,10 @@ def _setup(cls):
5255
"'masterKey' and 'host' at the top of this class to run the "
5356
"tests.")
5457

55-
cls.client = cosmos_client.CosmosClient(_config.host, _config.master_key)
56-
cls.created_db = cls.client.get_database_client(test_config.TestConfig.TEST_DATABASE_ID)
57-
cls.created_collection = cls._create_collection(cls.created_db)
58+
cls.key_client, cls.key_db, cls.client, cls.created_db = (
59+
test_config.TestConfig.create_test_clients(test_config.TestConfig.TEST_DATABASE_ID))
60+
created_collection_ref = cls._create_collection(cls.key_db)
61+
cls.created_collection = cls.created_db.get_container_client(created_collection_ref.id)
5862

5963
# test documents
6064
document_definitions = []
@@ -138,6 +142,62 @@ def test_run_all(self):
138142
print(test_name + ': ' + query + " FAILED")
139143
raise e
140144

145+
# AAD-only smoke subset.
146+
#
147+
# Why this exists: the CI AAD lane runs on Linux and the shared
148+
# ``azpysdk.main whl --isolate`` bootstrap on that pool already eats
149+
# ~90 minutes of the 120-minute job ceiling. Running the full
150+
# ``test_run_all`` matrix (24 aggregate variants) under AAD on top of
151+
# that bootstrap pushes the lane over the ceiling. The full matrix
152+
# still runs under the ``cosmosQuery`` lane (key auth) -- this method
153+
# is *additional* AAD-only coverage focused on Contoso's exact bug
154+
# shape: cross-partition aggregate query under bearer auth, including
155+
# the ORDER BY pagination case where token refresh mid-stream is most
156+
# likely to surface.
157+
#
158+
# Three queries: cross-partition COUNT (fan-out), cross-partition SUM
159+
# with ORDER BY (fan-out + paginated reduce -> token-refresh window),
160+
# single-partition AVG (pinned-PK path).
161+
@pytest.mark.cosmosAADLong
162+
@pytest.mark.skipif(
163+
test_config.TestConfig.data_auth_mode != "aad",
164+
reason="AAD-only smoke subset; full coverage runs under cosmosQuery (key auth).",
165+
)
166+
def test_aad_aggregate_subset(self):
167+
same_partition_avg = (
168+
_config.DOCS_WITH_SAME_PARTITION_KEY * (_config.DOCS_WITH_SAME_PARTITION_KEY + 1) / 2.0
169+
) / _config.DOCS_WITH_SAME_PARTITION_KEY
170+
subset = [
171+
(
172+
"test_aad_xp_count",
173+
"SELECT VALUE COUNT(r.{}) FROM r WHERE true".format(_config.PARTITION_KEY),
174+
_config.DOCUMENTS_COUNT,
175+
),
176+
(
177+
"test_aad_xp_sum_orderby",
178+
"SELECT VALUE SUM(r.{f}) FROM r WHERE IS_NUMBER(r.{pk}) ORDER BY r.{pk}".format(
179+
f=_config.PARTITION_KEY, pk=_config.PARTITION_KEY
180+
),
181+
_config.sum,
182+
),
183+
(
184+
"test_aad_sp_avg",
185+
"SELECT VALUE AVG(r.{f}) FROM r WHERE r.{pk} = '{val}'".format(
186+
f=_config.FIELD,
187+
pk=_config.PARTITION_KEY,
188+
val=_config.UNIQUE_PARTITION_KEY,
189+
),
190+
same_partition_avg,
191+
),
192+
]
193+
for test_name, query, expected in subset:
194+
try:
195+
self._run_one(query, expected)
196+
print(test_name + ': ' + query + " PASSED", flush=True)
197+
except Exception as e:
198+
print(test_name + ': ' + query + " FAILED", flush=True)
199+
raise e
200+
141201
def _run_one(self, query, expected_result):
142202
self._execute_query_and_validate_results(self.created_collection, query, expected_result)
143203

sdk/cosmos/azure-cosmos/tests/test_auto_scale.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
@pytest.mark.cosmosLong
1515
class TestAutoScale(unittest.TestCase):
16-
client: CosmosClient = None
16+
key_client: CosmosClient = None
1717
host = test_config.TestConfig.host
1818
masterKey = test_config.TestConfig.masterKey
1919
connectionPolicy = test_config.TestConfig.connectionPolicy
@@ -27,8 +27,8 @@ def setUpClass(cls):
2727
"'masterKey' and 'host' at the top of this class to run the "
2828
"tests.")
2929

30-
cls.client = cosmos_client.CosmosClient(cls.host, cls.masterKey)
31-
cls.created_database = cls.client.get_database_client(test_config.TestConfig.TEST_DATABASE_ID)
30+
cls.key_client = cosmos_client.CosmosClient(cls.host, cls.masterKey)
31+
cls.created_database = cls.key_client.get_database_client(test_config.TestConfig.TEST_DATABASE_ID)
3232

3333
def test_autoscale_create_container(self):
3434
container_id = None
@@ -75,7 +75,7 @@ def test_autoscale_create_database(self):
7575
database_id = "db_auto_scale_" + str(uuid.uuid4())
7676
try:
7777
# Testing auto_scale_settings for the create_database method
78-
created_database = self.client.create_database(database_id, offer_throughput=ThroughputProperties(
78+
created_database = self.key_client.create_database(database_id, offer_throughput=ThroughputProperties(
7979
auto_scale_max_throughput=5000,
8080
auto_scale_increment_percent=2))
8181
created_db_properties = created_database.get_throughput()
@@ -84,11 +84,11 @@ def test_autoscale_create_database(self):
8484
# Testing the input value of the increment_percentage
8585
assert created_db_properties.auto_scale_increment_percent == 2
8686

87-
self.client.delete_database(created_database.id)
87+
self.key_client.delete_database(created_database.id)
8888

8989
# Testing auto_scale_settings for the create_database_if_not_exists method
9090
database_id = "db_auto_scale_2_" + str(uuid.uuid4())
91-
created_database = self.client.create_database_if_not_exists(database_id,
91+
created_database = self.key_client.create_database_if_not_exists(database_id,
9292
offer_throughput=ThroughputProperties(
9393
auto_scale_max_throughput=9000,
9494
auto_scale_increment_percent=11))
@@ -98,13 +98,13 @@ def test_autoscale_create_database(self):
9898
# Testing the input value of the increment_percentage
9999
assert created_db_properties.auto_scale_increment_percent == 11
100100
finally:
101-
self.client.delete_database(database_id)
101+
self.key_client.delete_database(database_id)
102102

103103
def test_autoscale_replace_throughput(self):
104104
database_id = "replace_db" + str(uuid.uuid4())
105105
container_id = None
106106
try:
107-
created_database = self.client.create_database(database_id, offer_throughput=ThroughputProperties(
107+
created_database = self.key_client.create_database(database_id, offer_throughput=ThroughputProperties(
108108
auto_scale_max_throughput=5000,
109109
auto_scale_increment_percent=2))
110110
created_database.replace_throughput(
@@ -114,7 +114,7 @@ def test_autoscale_replace_throughput(self):
114114
assert created_db_properties.auto_scale_max_throughput == 7000
115115
# Testing the input value of the increment_percentage
116116
assert created_db_properties.auto_scale_increment_percent == 20
117-
self.client.delete_database(database_id)
117+
self.key_client.delete_database(database_id)
118118

119119
container_id = "container_with_auto_scale_settings" + str(uuid.uuid4())
120120
created_container = self.created_database.create_container(

0 commit comments

Comments
 (0)