Skip to content

Commit 32b4db0

Browse files
committed
use subtests for verify_pipeline
1 parent af55183 commit 32b4db0

3 files changed

Lines changed: 158 additions & 144 deletions

File tree

packages/google-cloud-firestore/noxfile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
SYSTEM_TEST_PYTHON_VERSIONS: List[str] = ALL_PYTHON
7272
SYSTEM_TEST_STANDARD_DEPENDENCIES = [
7373
"mock",
74-
"pytest",
74+
"pytest>9.0",
7575
"google-cloud-testutils",
7676
]
7777
SYSTEM_TEST_EXTERNAL_DEPENDENCIES: List[str] = [

packages/google-cloud-firestore/tests/system/test_system.py

Lines changed: 81 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -83,69 +83,73 @@ def cleanup():
8383
for operation in operations:
8484
operation()
8585

86-
87-
def verify_pipeline(query):
86+
@pytest.fixture
87+
def verify_pipeline(subtests):
8888
"""
89-
This function ensures a pipeline produces the same
90-
results as the query it is derived from
89+
This fixture provide a subtest function which
90+
ensures a pipeline produces the same results as the query it is derived
91+
from
9192
9293
It can be attached to existing query tests to check both
9394
modalities at the same time
9495
9596
Pipelines are only supported on enterprise dbs. Skip other environments
9697
"""
97-
from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery
98-
99-
client = query._client
100-
if FIRESTORE_EMULATOR:
101-
print("skip pipeline verification on emulator")
102-
return
103-
if client._database != FIRESTORE_ENTERPRISE_DB:
104-
print("pipelines only supports enterprise db")
105-
return
106-
107-
def _clean_results(results):
108-
if isinstance(results, dict):
109-
return {k: _clean_results(v) for k, v in results.items()}
110-
elif isinstance(results, list):
111-
return [_clean_results(r) for r in results]
112-
elif isinstance(results, float) and math.isnan(results):
113-
return "__NAN_VALUE__"
114-
else:
115-
return results
11698

117-
query_exception = None
118-
query_results = None
119-
try:
120-
try:
121-
if isinstance(query, BaseAggregationQuery):
122-
# aggregation queries return a list of lists of aggregation results
123-
query_results = _clean_results(
124-
list(
125-
itertools.chain.from_iterable(
126-
[[a._to_dict() for a in s] for s in query.get()]
99+
def _verifier(query):
100+
from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery
101+
with subtests.test(msg="verify_pipeline"):
102+
103+
client = query._client
104+
if FIRESTORE_EMULATOR:
105+
pytest.skip("skip pipeline verification on emulator")
106+
if client._database != FIRESTORE_ENTERPRISE_DB:
107+
pytest.skip("pipelines only supports enterprise db")
108+
109+
def _clean_results(results):
110+
if isinstance(results, dict):
111+
return {k: _clean_results(v) for k, v in results.items()}
112+
elif isinstance(results, list):
113+
return [_clean_results(r) for r in results]
114+
elif isinstance(results, float) and math.isnan(results):
115+
return "__NAN_VALUE__"
116+
else:
117+
return results
118+
119+
query_exception = None
120+
query_results = None
121+
try:
122+
try:
123+
if isinstance(query, BaseAggregationQuery):
124+
# aggregation queries return a list of lists of aggregation results
125+
query_results = _clean_results(
126+
list(
127+
itertools.chain.from_iterable(
128+
[[a._to_dict() for a in s] for s in query.get()]
129+
)
130+
)
127131
)
128-
)
129-
)
130-
else:
131-
# other qureies return a simple list of results
132-
query_results = _clean_results([s.to_dict() for s in query.get()])
133-
except Exception as e:
134-
# if we expect the query to fail, capture the exception
135-
query_exception = e
136-
pipeline = client.pipeline().create_from(query)
137-
if query_exception:
138-
# ensure that the pipeline uses same error as query
139-
with pytest.raises(query_exception.__class__):
140-
pipeline.execute()
141-
else:
142-
# ensure results match query
143-
pipeline_results = _clean_results([s.data() for s in pipeline.execute()])
144-
assert query_results == pipeline_results
145-
except FailedPrecondition as e:
146-
# if testing against a non-enterprise db, skip this check
147-
if ENTERPRISE_MODE_ERROR not in e.message:
148-
raise e
132+
else:
133+
# other qureies return a simple list of results
134+
query_results = _clean_results([s.to_dict() for s in query.get()])
135+
except Exception as e:
136+
# if we expect the query to fail, capture the exception
137+
query_exception = e
138+
pipeline = client.pipeline().create_from(query)
139+
if query_exception:
140+
# ensure that the pipeline uses same error as query
141+
with pytest.raises(query_exception.__class__):
142+
pipeline.execute()
143+
else:
144+
# ensure results match query
145+
pipeline_results = _clean_results([s.data() for s in pipeline.execute()])
146+
assert query_results == pipeline_results
147+
except FailedPrecondition as e:
148+
# if testing against a non-enterprise db, skip this check
149+
if ENTERPRISE_MODE_ERROR not in e.message:
150+
raise e
151+
152+
return _verifier
149153

150154

151155
@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
@@ -1300,7 +1304,7 @@ def query(collection):
13001304

13011305

13021306
@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
1303-
def test_query_stream_legacy_where(query_docs, database):
1307+
def test_query_stream_legacy_where(query_docs, database, verify_pipeline):
13041308
"""Assert the legacy code still works and returns value"""
13051309
collection, stored, allowed_vals = query_docs
13061310
with pytest.warns(
@@ -1317,7 +1321,7 @@ def test_query_stream_legacy_where(query_docs, database):
13171321

13181322

13191323
@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
1320-
def test_query_stream_w_simple_field_eq_op(query_docs, database):
1324+
def test_query_stream_w_simple_field_eq_op(query_docs, database, verify_pipeline):
13211325
collection, stored, allowed_vals = query_docs
13221326
query = collection.where(filter=FieldFilter("a", "==", 1))
13231327
values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()}
@@ -1329,7 +1333,7 @@ def test_query_stream_w_simple_field_eq_op(query_docs, database):
13291333

13301334

13311335
@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
1332-
def test_query_stream_w_simple_field_array_contains_op(query_docs, database):
1336+
def test_query_stream_w_simple_field_array_contains_op(query_docs, database, verify_pipeline):
13331337
collection, stored, allowed_vals = query_docs
13341338
query = collection.where(filter=FieldFilter("c", "array_contains", 1))
13351339
values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()}
@@ -1341,7 +1345,7 @@ def test_query_stream_w_simple_field_array_contains_op(query_docs, database):
13411345

13421346

13431347
@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
1344-
def test_query_stream_w_simple_field_in_op(query_docs, database):
1348+
def test_query_stream_w_simple_field_in_op(query_docs, database, verify_pipeline):
13451349
collection, stored, allowed_vals = query_docs
13461350
num_vals = len(allowed_vals)
13471351
query = collection.where(filter=FieldFilter("a", "in", [1, num_vals + 100]))
@@ -1354,7 +1358,7 @@ def test_query_stream_w_simple_field_in_op(query_docs, database):
13541358

13551359

13561360
@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
1357-
def test_query_stream_w_not_eq_op(query_docs, database):
1361+
def test_query_stream_w_not_eq_op(query_docs, database, verify_pipeline):
13581362
collection, stored, allowed_vals = query_docs
13591363
query = collection.where(filter=FieldFilter("stats.sum", "!=", 4))
13601364
values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()}
@@ -1377,7 +1381,7 @@ def test_query_stream_w_not_eq_op(query_docs, database):
13771381

13781382

13791383
@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
1380-
def test_query_stream_w_simple_not_in_op(query_docs, database):
1384+
def test_query_stream_w_simple_not_in_op(query_docs, database, verify_pipeline):
13811385
collection, stored, allowed_vals = query_docs
13821386
num_vals = len(allowed_vals)
13831387
query = collection.where(
@@ -1390,7 +1394,7 @@ def test_query_stream_w_simple_not_in_op(query_docs, database):
13901394

13911395

13921396
@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
1393-
def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database):
1397+
def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database, verify_pipeline):
13941398
collection, stored, allowed_vals = query_docs
13951399
num_vals = len(allowed_vals)
13961400
query = collection.where(
@@ -1405,7 +1409,7 @@ def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database)
14051409

14061410

14071411
@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
1408-
def test_query_stream_w_order_by(query_docs, database):
1412+
def test_query_stream_w_order_by(query_docs, database, verify_pipeline):
14091413
collection, stored, allowed_vals = query_docs
14101414
query = collection.order_by("b", direction=firestore.Query.DESCENDING)
14111415
values = [(snapshot.id, snapshot.to_dict()) for snapshot in query.stream()]
@@ -1420,7 +1424,7 @@ def test_query_stream_w_order_by(query_docs, database):
14201424

14211425

14221426
@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
1423-
def test_query_stream_w_field_path(query_docs, database):
1427+
def test_query_stream_w_field_path(query_docs, database, verify_pipeline):
14241428
collection, stored, allowed_vals = query_docs
14251429
query = collection.where(filter=FieldFilter("stats.sum", ">", 4))
14261430
values = {snapshot.id: snapshot.to_dict() for snapshot in query.stream()}
@@ -1459,7 +1463,7 @@ def test_query_stream_w_start_end_cursor(query_docs, database):
14591463

14601464

14611465
@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
1462-
def test_query_stream_wo_results(query_docs, database):
1466+
def test_query_stream_wo_results(query_docs, database, verify_pipeline):
14631467
collection, stored, allowed_vals = query_docs
14641468
num_vals = len(allowed_vals)
14651469
query = collection.where(filter=FieldFilter("b", "==", num_vals + 100))
@@ -1486,7 +1490,7 @@ def test_query_stream_w_projection(query_docs, database):
14861490

14871491

14881492
@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
1489-
def test_query_stream_w_multiple_filters(query_docs, database):
1493+
def test_query_stream_w_multiple_filters(query_docs, database, verify_pipeline):
14901494
collection, stored, allowed_vals = query_docs
14911495
query = collection.where(filter=FieldFilter("stats.product", ">", 5)).where(
14921496
filter=FieldFilter("stats.product", "<", 10)
@@ -1507,7 +1511,7 @@ def test_query_stream_w_multiple_filters(query_docs, database):
15071511

15081512

15091513
@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
1510-
def test_query_stream_w_offset(query_docs, database):
1514+
def test_query_stream_w_offset(query_docs, database, verify_pipeline):
15111515
collection, stored, allowed_vals = query_docs
15121516
num_vals = len(allowed_vals)
15131517
offset = 3
@@ -1528,7 +1532,7 @@ def test_query_stream_w_offset(query_docs, database):
15281532
)
15291533
@pytest.mark.parametrize("method", ["stream", "get"])
15301534
@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
1531-
def test_query_stream_or_get_w_no_explain_options(query_docs, database, method):
1535+
def test_query_stream_or_get_w_no_explain_options(query_docs, database, method, verify_pipeline):
15321536
from google.cloud.firestore_v1.query_profile import QueryExplainError
15331537

15341538
collection, _, allowed_vals = query_docs
@@ -1892,7 +1896,7 @@ def test_query_with_order_dot_key(client, cleanup, database):
18921896

18931897

18941898
@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
1895-
def test_query_unary(client, cleanup, database):
1899+
def test_query_unary(client, cleanup, database, verify_pipeline):
18961900
collection_name = "unary" + UNIQUE_RESOURCE_ID
18971901
collection = client.collection(collection_name)
18981902
field_name = "foo"
@@ -1949,7 +1953,7 @@ def test_query_unary(client, cleanup, database):
19491953

19501954

19511955
@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
1952-
def test_collection_group_queries(client, cleanup, database):
1956+
def test_collection_group_queries(client, cleanup, database, verify_pipeline):
19531957
collection_group = "b" + UNIQUE_RESOURCE_ID
19541958

19551959
doc_paths = [
@@ -2026,7 +2030,7 @@ def test_collection_group_queries_startat_endat(client, cleanup, database):
20262030

20272031

20282032
@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
2029-
def test_collection_group_queries_filters(client, cleanup, database):
2033+
def test_collection_group_queries_filters(client, cleanup, database, verify_pipeline):
20302034
collection_group = "b" + UNIQUE_RESOURCE_ID
20312035

20322036
doc_paths = [
@@ -2817,7 +2821,7 @@ def on_snapshot(docs, changes, read_time):
28172821

28182822

28192823
@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
2820-
def test_repro_429(client, cleanup, database):
2824+
def test_repro_429(client, cleanup, database, verify_pipeline):
28212825
# See: https://github.com/googleapis/python-firestore/issues/429
28222826
now = datetime.datetime.now(tz=datetime.timezone.utc)
28232827
collection = client.collection("repro-429" + UNIQUE_RESOURCE_ID)
@@ -3412,7 +3416,7 @@ def test_aggregation_query_stream_or_get_w_explain_options_analyze_false(
34123416

34133417

34143418
@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
3415-
def test_query_with_and_composite_filter(collection, database):
3419+
def test_query_with_and_composite_filter(collection, database, verify_pipeline):
34163420
and_filter = And(
34173421
filters=[
34183422
FieldFilter("stats.product", ">", 5),
@@ -3428,7 +3432,7 @@ def test_query_with_and_composite_filter(collection, database):
34283432

34293433

34303434
@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
3431-
def test_query_with_or_composite_filter(collection, database):
3435+
def test_query_with_or_composite_filter(collection, database, verify_pipeline):
34323436
or_filter = Or(
34333437
filters=[
34343438
FieldFilter("stats.product", ">", 5),
@@ -3462,6 +3466,7 @@ def test_aggregation_queries_with_read_time(
34623466
database,
34633467
aggregation_type,
34643468
expected_value,
3469+
verify_pipeline,
34653470
):
34663471
"""
34673472
Ensure that all aggregation queries work when read_time is passed into
@@ -3500,7 +3505,7 @@ def test_aggregation_queries_with_read_time(
35003505

35013506

35023507
@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
3503-
def test_query_with_complex_composite_filter(collection, database):
3508+
def test_query_with_complex_composite_filter(collection, database, verify_pipeline):
35043509
field_filter = FieldFilter("b", "==", 0)
35053510
or_filter = Or(
35063511
filters=[FieldFilter("stats.sum", "==", 0), FieldFilter("stats.sum", "==", 4)]
@@ -3558,6 +3563,7 @@ def test_aggregation_query_in_transaction(
35583563
aggregation_type,
35593564
aggregation_args,
35603565
expected,
3566+
verify_pipeline,
35613567
):
35623568
"""
35633569
Test creating an aggregation query inside a transaction
@@ -3599,7 +3605,7 @@ def in_transaction(transaction):
35993605

36003606

36013607
@pytest.mark.parametrize("database", TEST_DATABASES_W_ENTERPRISE, indirect=True)
3602-
def test_or_query_in_transaction(client, cleanup, database):
3608+
def test_or_query_in_transaction(client, cleanup, database, verify_pipeline):
36033609
"""
36043610
Test running or query inside a transaction. Should pass transaction id along with request
36053611
"""

0 commit comments

Comments
 (0)