@@ -83,69 +83,73 @@ def cleanup():
8383 for operation in operations :
8484 operation ()
8585
86-
87- def verify_pipeline (query ):
86+ @ pytest . fixture
87+ def verify_pipeline (subtests ):
8888 """
89- This function ensures a pipeline produces the same
90- results as the query it is derived from
89+ This fixture provide a subtest function which
90+ ensures a pipeline produces the same results as the query it is derived
91+ from
9192
9293 It can be attached to existing query tests to check both
9394 modalities at the same time
9495
9596 Pipelines are only supported on enterprise dbs. Skip other environments
9697 """
97- from google .cloud .firestore_v1 .base_aggregation import BaseAggregationQuery
98-
99- client = query ._client
100- if FIRESTORE_EMULATOR :
101- print ("skip pipeline verification on emulator" )
102- return
103- if client ._database != FIRESTORE_ENTERPRISE_DB :
104- print ("pipelines only supports enterprise db" )
105- return
106-
107- def _clean_results (results ):
108- if isinstance (results , dict ):
109- return {k : _clean_results (v ) for k , v in results .items ()}
110- elif isinstance (results , list ):
111- return [_clean_results (r ) for r in results ]
112- elif isinstance (results , float ) and math .isnan (results ):
113- return "__NAN_VALUE__"
114- else :
115- return results
11698
117- query_exception = None
118- query_results = None
119- try :
120- try :
121- if isinstance (query , BaseAggregationQuery ):
122- # aggregation queries return a list of lists of aggregation results
123- query_results = _clean_results (
124- list (
125- itertools .chain .from_iterable (
126- [[a ._to_dict () for a in s ] for s in query .get ()]
99+ def _verifier (query ):
100+ from google .cloud .firestore_v1 .base_aggregation import BaseAggregationQuery
101+ with subtests .test (msg = "verify_pipeline" ):
102+
103+ client = query ._client
104+ if FIRESTORE_EMULATOR :
105+ pytest .skip ("skip pipeline verification on emulator" )
106+ if client ._database != FIRESTORE_ENTERPRISE_DB :
107+ pytest .skip ("pipelines only supports enterprise db" )
108+
109+ def _clean_results (results ):
110+ if isinstance (results , dict ):
111+ return {k : _clean_results (v ) for k , v in results .items ()}
112+ elif isinstance (results , list ):
113+ return [_clean_results (r ) for r in results ]
114+ elif isinstance (results , float ) and math .isnan (results ):
115+ return "__NAN_VALUE__"
116+ else :
117+ return results
118+
119+ query_exception = None
120+ query_results = None
121+ try :
122+ try :
123+ if isinstance (query , BaseAggregationQuery ):
124+ # aggregation queries return a list of lists of aggregation results
125+ query_results = _clean_results (
126+ list (
127+ itertools .chain .from_iterable (
128+ [[a ._to_dict () for a in s ] for s in query .get ()]
129+ )
130+ )
127131 )
128- )
129- )
130- else :
131- # other qureies return a simple list of results
132- query_results = _clean_results ([ s . to_dict () for s in query . get ()])
133- except Exception as e :
134- # if we expect the query to fail, capture the exception
135- query_exception = e
136- pipeline = client . pipeline (). create_from ( query )
137- if query_exception :
138- # ensure that the pipeline uses same error as query
139- with pytest . raises ( query_exception . __class__ ) :
140- pipeline . execute ()
141- else :
142- # ensure results match query
143- pipeline_results = _clean_results ([ s . data () for s in pipeline . execute ()])
144- assert query_results == pipeline_results
145- except FailedPrecondition as e :
146- # if testing against a non-enterprise db, skip this check
147- if ENTERPRISE_MODE_ERROR not in e . message :
148- raise e
132+ else :
133+ # other qureies return a simple list of results
134+ query_results = _clean_results ([ s . to_dict () for s in query . get ()])
135+ except Exception as e :
136+ # if we expect the query to fail, capture the exception
137+ query_exception = e
138+ pipeline = client . pipeline (). create_from ( query )
139+ if query_exception :
140+ # ensure that the pipeline uses same error as query
141+ with pytest . raises ( query_exception . __class__ ) :
142+ pipeline . execute ()
143+ else :
144+ # ensure results match query
145+ pipeline_results = _clean_results ([ s . data () for s in pipeline . execute ()])
146+ assert query_results == pipeline_results
147+ except FailedPrecondition as e :
148+ # if testing against a non-enterprise db, skip this check
149+ if ENTERPRISE_MODE_ERROR not in e . message :
150+ raise e
151+
152+ return _verifier
149153
150154
151155@pytest .mark .parametrize ("database" , TEST_DATABASES , indirect = True )
@@ -1300,7 +1304,7 @@ def query(collection):
13001304
13011305
13021306@pytest .mark .parametrize ("database" , TEST_DATABASES_W_ENTERPRISE , indirect = True )
1303- def test_query_stream_legacy_where (query_docs , database ):
1307+ def test_query_stream_legacy_where (query_docs , database , verify_pipeline ):
13041308 """Assert the legacy code still works and returns value"""
13051309 collection , stored , allowed_vals = query_docs
13061310 with pytest .warns (
@@ -1317,7 +1321,7 @@ def test_query_stream_legacy_where(query_docs, database):
13171321
13181322
13191323@pytest .mark .parametrize ("database" , TEST_DATABASES_W_ENTERPRISE , indirect = True )
1320- def test_query_stream_w_simple_field_eq_op (query_docs , database ):
1324+ def test_query_stream_w_simple_field_eq_op (query_docs , database , verify_pipeline ):
13211325 collection , stored , allowed_vals = query_docs
13221326 query = collection .where (filter = FieldFilter ("a" , "==" , 1 ))
13231327 values = {snapshot .id : snapshot .to_dict () for snapshot in query .stream ()}
@@ -1329,7 +1333,7 @@ def test_query_stream_w_simple_field_eq_op(query_docs, database):
13291333
13301334
13311335@pytest .mark .parametrize ("database" , TEST_DATABASES_W_ENTERPRISE , indirect = True )
1332- def test_query_stream_w_simple_field_array_contains_op (query_docs , database ):
1336+ def test_query_stream_w_simple_field_array_contains_op (query_docs , database , verify_pipeline ):
13331337 collection , stored , allowed_vals = query_docs
13341338 query = collection .where (filter = FieldFilter ("c" , "array_contains" , 1 ))
13351339 values = {snapshot .id : snapshot .to_dict () for snapshot in query .stream ()}
@@ -1341,7 +1345,7 @@ def test_query_stream_w_simple_field_array_contains_op(query_docs, database):
13411345
13421346
13431347@pytest .mark .parametrize ("database" , TEST_DATABASES_W_ENTERPRISE , indirect = True )
1344- def test_query_stream_w_simple_field_in_op (query_docs , database ):
1348+ def test_query_stream_w_simple_field_in_op (query_docs , database , verify_pipeline ):
13451349 collection , stored , allowed_vals = query_docs
13461350 num_vals = len (allowed_vals )
13471351 query = collection .where (filter = FieldFilter ("a" , "in" , [1 , num_vals + 100 ]))
@@ -1354,7 +1358,7 @@ def test_query_stream_w_simple_field_in_op(query_docs, database):
13541358
13551359
13561360@pytest .mark .parametrize ("database" , TEST_DATABASES_W_ENTERPRISE , indirect = True )
1357- def test_query_stream_w_not_eq_op (query_docs , database ):
1361+ def test_query_stream_w_not_eq_op (query_docs , database , verify_pipeline ):
13581362 collection , stored , allowed_vals = query_docs
13591363 query = collection .where (filter = FieldFilter ("stats.sum" , "!=" , 4 ))
13601364 values = {snapshot .id : snapshot .to_dict () for snapshot in query .stream ()}
@@ -1377,7 +1381,7 @@ def test_query_stream_w_not_eq_op(query_docs, database):
13771381
13781382
13791383@pytest .mark .parametrize ("database" , TEST_DATABASES_W_ENTERPRISE , indirect = True )
1380- def test_query_stream_w_simple_not_in_op (query_docs , database ):
1384+ def test_query_stream_w_simple_not_in_op (query_docs , database , verify_pipeline ):
13811385 collection , stored , allowed_vals = query_docs
13821386 num_vals = len (allowed_vals )
13831387 query = collection .where (
@@ -1390,7 +1394,7 @@ def test_query_stream_w_simple_not_in_op(query_docs, database):
13901394
13911395
13921396@pytest .mark .parametrize ("database" , TEST_DATABASES_W_ENTERPRISE , indirect = True )
1393- def test_query_stream_w_simple_field_array_contains_any_op (query_docs , database ):
1397+ def test_query_stream_w_simple_field_array_contains_any_op (query_docs , database , verify_pipeline ):
13941398 collection , stored , allowed_vals = query_docs
13951399 num_vals = len (allowed_vals )
13961400 query = collection .where (
@@ -1405,7 +1409,7 @@ def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database)
14051409
14061410
14071411@pytest .mark .parametrize ("database" , TEST_DATABASES_W_ENTERPRISE , indirect = True )
1408- def test_query_stream_w_order_by (query_docs , database ):
1412+ def test_query_stream_w_order_by (query_docs , database , verify_pipeline ):
14091413 collection , stored , allowed_vals = query_docs
14101414 query = collection .order_by ("b" , direction = firestore .Query .DESCENDING )
14111415 values = [(snapshot .id , snapshot .to_dict ()) for snapshot in query .stream ()]
@@ -1420,7 +1424,7 @@ def test_query_stream_w_order_by(query_docs, database):
14201424
14211425
14221426@pytest .mark .parametrize ("database" , TEST_DATABASES_W_ENTERPRISE , indirect = True )
1423- def test_query_stream_w_field_path (query_docs , database ):
1427+ def test_query_stream_w_field_path (query_docs , database , verify_pipeline ):
14241428 collection , stored , allowed_vals = query_docs
14251429 query = collection .where (filter = FieldFilter ("stats.sum" , ">" , 4 ))
14261430 values = {snapshot .id : snapshot .to_dict () for snapshot in query .stream ()}
@@ -1459,7 +1463,7 @@ def test_query_stream_w_start_end_cursor(query_docs, database):
14591463
14601464
14611465@pytest .mark .parametrize ("database" , TEST_DATABASES_W_ENTERPRISE , indirect = True )
1462- def test_query_stream_wo_results (query_docs , database ):
1466+ def test_query_stream_wo_results (query_docs , database , verify_pipeline ):
14631467 collection , stored , allowed_vals = query_docs
14641468 num_vals = len (allowed_vals )
14651469 query = collection .where (filter = FieldFilter ("b" , "==" , num_vals + 100 ))
@@ -1486,7 +1490,7 @@ def test_query_stream_w_projection(query_docs, database):
14861490
14871491
14881492@pytest .mark .parametrize ("database" , TEST_DATABASES_W_ENTERPRISE , indirect = True )
1489- def test_query_stream_w_multiple_filters (query_docs , database ):
1493+ def test_query_stream_w_multiple_filters (query_docs , database , verify_pipeline ):
14901494 collection , stored , allowed_vals = query_docs
14911495 query = collection .where (filter = FieldFilter ("stats.product" , ">" , 5 )).where (
14921496 filter = FieldFilter ("stats.product" , "<" , 10 )
@@ -1507,7 +1511,7 @@ def test_query_stream_w_multiple_filters(query_docs, database):
15071511
15081512
15091513@pytest .mark .parametrize ("database" , TEST_DATABASES_W_ENTERPRISE , indirect = True )
1510- def test_query_stream_w_offset (query_docs , database ):
1514+ def test_query_stream_w_offset (query_docs , database , verify_pipeline ):
15111515 collection , stored , allowed_vals = query_docs
15121516 num_vals = len (allowed_vals )
15131517 offset = 3
@@ -1528,7 +1532,7 @@ def test_query_stream_w_offset(query_docs, database):
15281532)
15291533@pytest .mark .parametrize ("method" , ["stream" , "get" ])
15301534@pytest .mark .parametrize ("database" , TEST_DATABASES_W_ENTERPRISE , indirect = True )
1531- def test_query_stream_or_get_w_no_explain_options (query_docs , database , method ):
1535+ def test_query_stream_or_get_w_no_explain_options (query_docs , database , method , verify_pipeline ):
15321536 from google .cloud .firestore_v1 .query_profile import QueryExplainError
15331537
15341538 collection , _ , allowed_vals = query_docs
@@ -1892,7 +1896,7 @@ def test_query_with_order_dot_key(client, cleanup, database):
18921896
18931897
18941898@pytest .mark .parametrize ("database" , TEST_DATABASES , indirect = True )
1895- def test_query_unary (client , cleanup , database ):
1899+ def test_query_unary (client , cleanup , database , verify_pipeline ):
18961900 collection_name = "unary" + UNIQUE_RESOURCE_ID
18971901 collection = client .collection (collection_name )
18981902 field_name = "foo"
@@ -1949,7 +1953,7 @@ def test_query_unary(client, cleanup, database):
19491953
19501954
19511955@pytest .mark .parametrize ("database" , TEST_DATABASES_W_ENTERPRISE , indirect = True )
1952- def test_collection_group_queries (client , cleanup , database ):
1956+ def test_collection_group_queries (client , cleanup , database , verify_pipeline ):
19531957 collection_group = "b" + UNIQUE_RESOURCE_ID
19541958
19551959 doc_paths = [
@@ -2026,7 +2030,7 @@ def test_collection_group_queries_startat_endat(client, cleanup, database):
20262030
20272031
20282032@pytest .mark .parametrize ("database" , TEST_DATABASES_W_ENTERPRISE , indirect = True )
2029- def test_collection_group_queries_filters (client , cleanup , database ):
2033+ def test_collection_group_queries_filters (client , cleanup , database , verify_pipeline ):
20302034 collection_group = "b" + UNIQUE_RESOURCE_ID
20312035
20322036 doc_paths = [
@@ -2817,7 +2821,7 @@ def on_snapshot(docs, changes, read_time):
28172821
28182822
28192823@pytest .mark .parametrize ("database" , TEST_DATABASES_W_ENTERPRISE , indirect = True )
2820- def test_repro_429 (client , cleanup , database ):
2824+ def test_repro_429 (client , cleanup , database , verify_pipeline ):
28212825 # See: https://github.com/googleapis/python-firestore/issues/429
28222826 now = datetime .datetime .now (tz = datetime .timezone .utc )
28232827 collection = client .collection ("repro-429" + UNIQUE_RESOURCE_ID )
@@ -3412,7 +3416,7 @@ def test_aggregation_query_stream_or_get_w_explain_options_analyze_false(
34123416
34133417
34143418@pytest .mark .parametrize ("database" , TEST_DATABASES_W_ENTERPRISE , indirect = True )
3415- def test_query_with_and_composite_filter (collection , database ):
3419+ def test_query_with_and_composite_filter (collection , database , verify_pipeline ):
34163420 and_filter = And (
34173421 filters = [
34183422 FieldFilter ("stats.product" , ">" , 5 ),
@@ -3428,7 +3432,7 @@ def test_query_with_and_composite_filter(collection, database):
34283432
34293433
34303434@pytest .mark .parametrize ("database" , TEST_DATABASES_W_ENTERPRISE , indirect = True )
3431- def test_query_with_or_composite_filter (collection , database ):
3435+ def test_query_with_or_composite_filter (collection , database , verify_pipeline ):
34323436 or_filter = Or (
34333437 filters = [
34343438 FieldFilter ("stats.product" , ">" , 5 ),
@@ -3462,6 +3466,7 @@ def test_aggregation_queries_with_read_time(
34623466 database ,
34633467 aggregation_type ,
34643468 expected_value ,
3469+ verify_pipeline ,
34653470):
34663471 """
34673472 Ensure that all aggregation queries work when read_time is passed into
@@ -3500,7 +3505,7 @@ def test_aggregation_queries_with_read_time(
35003505
35013506
35023507@pytest .mark .parametrize ("database" , TEST_DATABASES_W_ENTERPRISE , indirect = True )
3503- def test_query_with_complex_composite_filter (collection , database ):
3508+ def test_query_with_complex_composite_filter (collection , database , verify_pipeline ):
35043509 field_filter = FieldFilter ("b" , "==" , 0 )
35053510 or_filter = Or (
35063511 filters = [FieldFilter ("stats.sum" , "==" , 0 ), FieldFilter ("stats.sum" , "==" , 4 )]
@@ -3558,6 +3563,7 @@ def test_aggregation_query_in_transaction(
35583563 aggregation_type ,
35593564 aggregation_args ,
35603565 expected ,
3566+ verify_pipeline ,
35613567):
35623568 """
35633569 Test creating an aggregation query inside a transaction
@@ -3599,7 +3605,7 @@ def in_transaction(transaction):
35993605
36003606
36013607@pytest .mark .parametrize ("database" , TEST_DATABASES_W_ENTERPRISE , indirect = True )
3602- def test_or_query_in_transaction (client , cleanup , database ):
3608+ def test_or_query_in_transaction (client , cleanup , database , verify_pipeline ):
36033609 """
36043610 Test running or query inside a transaction. Should pass transaction id along with request
36053611 """
0 commit comments