Skip to content

Commit 712ded8

Browse files
committed
allow passing string or expression into search
1 parent 4415bbb commit 712ded8

File tree

4 files changed

+27
-5
lines changed

4 files changed

+27
-5
lines changed

packages/google-cloud-firestore/google/cloud/firestore_v1/base_pipeline.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,7 @@ def sort(self, *orders: stages.Ordering) -> "_BasePipeline":
394394
"""
395395
return self._append(stages.Sort(*orders))
396396

397-
def search(self, options: stages.SearchOptions) -> "_BasePipeline":
397+
def search(self, query_or_options: str | BooleanExpression | stages.SearchOptions) -> "_BasePipeline":
398398
"""
399399
Adds a search stage to the pipeline.
400400
@@ -403,7 +403,7 @@ def search(self, options: stages.SearchOptions) -> "_BasePipeline":
403403
Example:
404404
>>> from google.cloud.firestore_v1.pipeline_stages import SearchOptions
405405
>>> from google.cloud.firestore_v1.pipeline_expressions import And, DocumentMatches, Field, GeoPoint
406-
>>> # Search for restaurants matching "waffles" within 1000m of a location
406+
>>> # Search for restaurants matching either "waffles" or "pancakes" near a location
407407
>>> pipeline = client.pipeline().collection("restaurants").search(
408408
... SearchOptions(
409409
... query=And(
@@ -415,12 +415,13 @@ def search(self, options: stages.SearchOptions) -> "_BasePipeline":
415415
... )
416416
417417
Args:
418-
options: A SearchOptions instance configuring the search.
418+
options: Either a string or expression representing the search query, or
419+
A `SearchOptions` instance configuring the search.
419420
420421
Returns:
421422
A new Pipeline object with this stage appended to the stage list
422423
"""
423-
return self._append(stages.Search(options))
424+
return self._append(stages.Search(query_or_options))
424425

425426
def sample(self, limit_or_options: int | stages.SampleOptions) -> "_BasePipeline":
426427
"""

packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_stages.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -500,8 +500,12 @@ def _pb_args(self):
500500
class Search(Stage):
501501
"""Search stage."""
502502

503-
def __init__(self, options: SearchOptions):
503+
def __init__(self, query_or_options: str | BooleanExpression | SearchOptions):
504504
super().__init__("search")
505+
if isinstance(query_or_options, SearchOptions):
506+
options = query_or_options
507+
else:
508+
options = SearchOptions(query=query_or_options)
505509
self.options = options
506510

507511
def _pb_args(self) -> list[Value]:

packages/google-cloud-firestore/tests/unit/v1/test_pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,8 @@ def test_pipeline_execute_stream_equivalence():
403403
("replace_with", (Field.of("n"),), stages.ReplaceWith),
404404
("sort", (Field.of("n").descending(),), stages.Sort),
405405
("sort", (Field.of("n").descending(), Field.of("m").ascending()), stages.Sort),
406+
("search", ("my query",), stages.Search),
407+
("search", (stages.SearchOptions(query="my query"),), stages.Search),
406408
("sample", (10,), stages.Sample),
407409
("sample", (stages.SampleOptions.doc_limit(10),), stages.Sample),
408410
("union", (_make_pipeline(),), stages.Union),

packages/google-cloud-firestore/tests/unit/v1/test_pipeline_stages.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -827,6 +827,21 @@ def test_search_string_query_wrapping(self):
827827
assert options.query.name == "document_matches"
828828
assert options.query.params[0].value == "science"
829829

830+
def test_search_with_string(self):
831+
stage = stages.Search("technology")
832+
assert isinstance(stage.options, stages.SearchOptions)
833+
assert stage.options.query.name == "document_matches"
834+
assert stage.options.query.params[0].value == "technology"
835+
pb_opts = stage._pb_options()
836+
assert "query" in pb_opts
837+
838+
def test_search_with_boolean_expression(self):
839+
expr = DocumentMatches("tech")
840+
stage = stages.Search(expr)
841+
assert isinstance(stage.options, stages.SearchOptions)
842+
assert stage.options.query is expr
843+
pb_opts = stage._pb_options()
844+
assert "query" in pb_opts
830845

831846
class TestSelect:
832847
def _make_one(self, *args, **kwargs):

0 commit comments

Comments
 (0)