Skip to content

Commit 700eed6

Browse files
committed
made tests more forgiving for score values
1 parent 34ca677 commit 700eed6

File tree

2 files changed

+68
-68
lines changed

2 files changed

+68
-68
lines changed

packages/google-cloud-firestore/tests/system/pipeline_e2e/search.yaml

Lines changed: 43 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,6 @@ tests:
4343
offset: 1
4444
query_enhancement: disabled
4545
language_code: en
46-
assert_results:
47-
- name: El Sol Tacos
48-
description: A vibrant street-side taco stand serving up quick, delicious, and
49-
traditional Mexican street food.
50-
location: GEOPOINT(39.6952, -105.0274)
51-
menu: <h3>Tacos ($3.50 each)</h3><ul><li>Al Pastor</li><li>Carne Asada</li><li>Pollo
52-
Asado</li><li>Nopales (Cactus)</li></ul><h3>Beverages</h3><ul><li>Horchata -
53-
$4</li><li>Mexican Coke - $3</li></ul>
54-
average_price_per_person: 12
5546
assert_proto:
5647
pipeline:
5748
stages:
@@ -75,6 +66,7 @@ tests:
7566
args:
7667
- stringValue: tacos
7768
name: document_matches
69+
assert_count: 1
7870
- description: search_stage_with_sort_and_add_fields
7971
pipeline:
8072
- Collection: restaurants
@@ -89,38 +81,31 @@ tests:
8981
- AliasedExpression:
9082
- Score: []
9183
- search_score
84+
- Select:
85+
- name
86+
- search_score
9287
assert_results_approximate:
93-
- name: Eastside Cantina
94-
description: Authentic street tacos and hand-shaken margaritas on the vibrant
95-
east side of the city.
96-
location: GEOPOINT(39.735, -104.885)
97-
menu: <h3>Tacos</h3><ul><li>Carnitas Tacos - $4</li><li>Barbacoa Tacos - $4.50</li><li>Shrimp
98-
Tacos - $5</li></ul><h3>Drinks</h3><ul><li>House Margarita - $9</li><li>Jarritos
99-
- $3</li></ul>
100-
average_price_per_person: 18
101-
search_score: 0.5137
102-
- name: El Sol Tacos
103-
description: A vibrant street-side taco stand serving up quick, delicious, and
104-
traditional Mexican street food.
105-
location: GEOPOINT(39.6952, -105.0274)
106-
menu: <h3>Tacos ($3.50 each)</h3><ul><li>Al Pastor</li><li>Carne Asada</li><li>Pollo
107-
Asado</li><li>Nopales (Cactus)</li></ul><h3>Beverages</h3><ul><li>Horchata -
108-
$4</li><li>Mexican Coke - $3</li></ul>
109-
average_price_per_person: 12
110-
search_score: 0.4731
88+
config:
89+
# be flexible in score values, but should be > 0
90+
absolute_tolerance: 0.99
91+
data:
92+
- name: Eastside Cantina
93+
search_score: 1.0
94+
- name: El Sol Tacos
95+
search_score: 1.0
11196
assert_proto:
11297
pipeline:
11398
stages:
114-
- args:
99+
- name: collection
100+
args:
115101
- referenceValue: /restaurants
116-
name: collection
117102
- name: search
118103
options:
119104
query:
120105
functionValue:
106+
name: document_matches
121107
args:
122108
- stringValue: tacos
123-
name: document_matches
124109
sort:
125110
arrayValue:
126111
values:
@@ -137,6 +122,14 @@ tests:
137122
search_score:
138123
functionValue:
139124
name: score
125+
- name: select
126+
args:
127+
- mapValue:
128+
fields:
129+
name:
130+
fieldReferenceValue: name
131+
search_score:
132+
fieldReferenceValue: search_score
140133
- description: expression_between
141134
pipeline:
142135
- Collection: restaurants
@@ -255,28 +248,31 @@ tests:
255248
- AliasedExpression:
256249
- Score: []
257250
- search_score
258-
assert_results_approximate:
259-
- search_score: 0.5137
260-
- search_score: 0.4732
261251
assert_proto:
262252
pipeline:
263253
stages:
264-
- args:
254+
- name: collection
255+
args:
265256
- referenceValue: /restaurants
266-
name: collection
267257
- name: search
268258
options:
269259
query:
270260
functionValue:
261+
name: document_matches
271262
args:
272263
- stringValue: tacos
273-
name: document_matches
274264
select:
275265
mapValue:
276266
fields:
277267
search_score:
278268
functionValue:
279269
name: score
270+
assert_results_approximate:
271+
config:
272+
absolute_tolerance: 0.99
273+
data:
274+
- search_score: 1.0
275+
- search_score: 1.0
280276
- description: search_full_document
281277
pipeline:
282278
- Collection: restaurants
@@ -458,36 +454,39 @@ tests:
458454
- Select:
459455
- name
460456
- searchScore
461-
assert_results_approximate:
462-
- name: The Golden Waffle
463-
searchScore: 0.5919
464457
assert_proto:
465458
pipeline:
466459
stages:
467-
- args:
460+
- name: collection
461+
args:
468462
- referenceValue: /restaurants
469-
name: collection
470463
- name: search
471464
options:
472465
query:
473466
functionValue:
467+
name: document_matches
474468
args:
475469
- stringValue: waffles
476-
name: document_matches
477470
add_fields:
478471
mapValue:
479472
fields:
480473
searchScore:
481474
functionValue:
482475
name: score
483-
- args:
476+
- name: select
477+
args:
484478
- mapValue:
485479
fields:
486480
name:
487481
fieldReferenceValue: name
488482
searchScore:
489483
fieldReferenceValue: searchScore
490-
name: select
484+
assert_results_approximate:
485+
config:
486+
absolute_tolerance: 0.99
487+
data:
488+
- name: The Golden Waffle
489+
searchScore: 1.0
491490
- description: search_sort_by_score
492491
pipeline:
493492
- Collection: restaurants

packages/google-cloud-firestore/tests/system/test_pipeline_acceptance.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,29 @@ def test_pipeline_expected_errors(test_dict, client):
133133
assert match, f"error '{found_error}' does not match '{error_regex}'"
134134

135135

136+
137+
def _assert_pipeline_results(got_results, expected_results, expected_approximate_results, expected_count):
138+
if expected_results:
139+
assert got_results == expected_results
140+
if expected_approximate_results is not None:
141+
tolerance = 1e-4
142+
if isinstance(expected_approximate_results, dict) and "data" in expected_approximate_results:
143+
if "config" in expected_approximate_results and "absolute_tolerance" in expected_approximate_results["config"]:
144+
tolerance = expected_approximate_results["config"]["absolute_tolerance"]
145+
expected_approximate_results = expected_approximate_results["data"]
146+
147+
assert len(got_results) == len(expected_approximate_results), (
148+
"got unexpected result count"
149+
)
150+
for idx in range(len(got_results)):
151+
expected = expected_approximate_results[idx]
152+
assert got_results[idx] == pytest.approx(
153+
expected, abs=tolerance
154+
)
155+
if expected_count is not None:
156+
assert len(got_results) == expected_count
157+
158+
136159
@pytest.mark.parametrize(
137160
"test_dict",
138161
[
@@ -158,18 +181,7 @@ def test_pipeline_results(test_dict, client):
158181
pipeline = parse_pipeline(client, test_dict["pipeline"])
159182
# check if server responds as expected
160183
got_results = [snapshot.data() for snapshot in pipeline.stream()]
161-
if expected_results:
162-
assert got_results == expected_results
163-
if expected_approximate_results:
164-
assert len(got_results) == len(expected_approximate_results), (
165-
"got unexpected result count"
166-
)
167-
for idx in range(len(got_results)):
168-
assert got_results[idx] == pytest.approx(
169-
expected_approximate_results[idx], abs=1e-4
170-
)
171-
if expected_count is not None:
172-
assert len(got_results) == expected_count
184+
_assert_pipeline_results(got_results, expected_results, expected_approximate_results, expected_count)
173185
if expected_end_state:
174186
for doc_path, expected_content in expected_end_state.items():
175187
doc_ref = client.document(doc_path)
@@ -231,18 +243,7 @@ async def test_pipeline_results_async(test_dict, async_client):
231243
pipeline = parse_pipeline(async_client, test_dict["pipeline"])
232244
# check if server responds as expected
233245
got_results = [snapshot.data() async for snapshot in pipeline.stream()]
234-
if expected_results:
235-
assert got_results == expected_results
236-
if expected_approximate_results:
237-
assert len(got_results) == len(expected_approximate_results), (
238-
"got unexpected result count"
239-
)
240-
for idx in range(len(got_results)):
241-
assert got_results[idx] == pytest.approx(
242-
expected_approximate_results[idx], abs=1e-4
243-
)
244-
if expected_count is not None:
245-
assert len(got_results) == expected_count
246+
_assert_pipeline_results(got_results, expected_results, expected_approximate_results, expected_count)
246247
if expected_end_state:
247248
for doc_path, expected_content in expected_end_state.items():
248249
doc_ref = async_client.document(doc_path)

0 commit comments

Comments
 (0)