Skip to content

Commit 750452c

Browse files
committed
initial commit
1 parent 943a979 commit 750452c

8 files changed

Lines changed: 312 additions & 2 deletions

File tree

packages/google-cloud-firestore/google/cloud/firestore_v1/base_pipeline.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,20 @@ def sort(self, *orders: stages.Ordering) -> "_BasePipeline":
344344
"""
345345
return self._append(stages.Sort(*orders))
346346

347+
def search(self, options: stages.SearchOptions) -> "_BasePipeline":
    """
    Appends a search stage to the pipeline.

    The stage is built from the supplied options, which carry the query
    used to filter documents along with any other search settings.

    Args:
        options: The `SearchOptions` instance that configures the search.

    Returns:
        A new Pipeline object with this stage appended to the stage list
    """
    search_stage = stages.Search(options)
    return self._append(search_stage)
360+
347361
def sample(self, limit_or_options: int | stages.SampleOptions) -> "_BasePipeline":
348362
"""
349363
Performs a pseudo-random sampling of the documents from the previous stage.

packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_expressions.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,50 @@ def less_than_or_equal(
609609
[self, self._cast_to_expr_or_convert_to_constant(other)],
610610
)
611611

612+
@expose_as_static
def between(
    self, lower: Expression | CONSTANT_TYPE, upper: Expression | CONSTANT_TYPE
) -> "BooleanExpression":
    """Builds a `between` check of this expression against two bounds.

    Example:
        >>> # Check if the 'age' field is between 18 and 65
        >>> Field.of("age").between(18, 65)

    Args:
        lower: The lower bound, given as an expression or a constant value.
        upper: The upper bound, given as an expression or a constant value.

    Returns:
        A new `BooleanExpression` named "between" whose parameters are this
        expression, the lower bound, and the upper bound, in that order.
    """
    # Each bound is passed through the shared cast/convert helper so that
    # both expressions and constants are accepted.
    lower_bound = self._cast_to_expr_or_convert_to_constant(lower)
    upper_bound = self._cast_to_expr_or_convert_to_constant(upper)
    return BooleanExpression("between", [self, lower_bound, upper_bound])
637+
638+
@expose_as_static
def geo_distance(self, other: Expression | CONSTANT_TYPE) -> "FunctionExpression":
    """Builds an expression for the distance between two geographical points.

    Example:
        >>> # Calculate distance between the 'location' field and a target GeoPoint
        >>> Field.of("location").geo_distance(target_point)

    Args:
        other: The second point, given as an expression or a constant value.

    Returns:
        A new `FunctionExpression` named "geo_distance" whose parameters are
        this expression followed by the other point.
    """
    other_point = self._cast_to_expr_or_convert_to_constant(other)
    return FunctionExpression("geo_distance", [self, other_point])
655+
612656
@expose_as_static
613657
def equal_any(
614658
self, array: Array | Sequence[Expression | CONSTANT_TYPE] | Expression
@@ -2634,3 +2678,22 @@ class Rand(FunctionExpression):
26342678

26352679
def __init__(self):
26362680
super().__init__("rand", [], use_infix_repr=False)
2681+
2682+
2683+
def document_matches(query: Expression | str) -> BooleanExpression:
    """Builds a boolean expression representing a document match query.

    Example:
        >>> # Find documents matching the query string
        >>> document_matches("search query")

    Args:
        query: The search query, given as a string or as an expression.

    Returns:
        A new `BooleanExpression` named "document_matches" with the query as
        its single parameter.
    """
    query_param = Expression._cast_to_expr_or_convert_to_constant(query)
    return BooleanExpression("document_matches", [query_param])
2699+

packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_stages.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,85 @@ def percentage(value: float):
109109
return SampleOptions(value, mode=SampleOptions.Mode.PERCENT)
110110

111111

112+
class QueryEnhancement(Enum):
    """Query expansion behavior used by full-text search expressions.

    Each member's value is the lowercase string form of its name; the
    `Search` stage sends that string as its `query_enhancement` option.
    """

    DISABLED = "disabled"
    REQUIRED = "required"
    PREFERRED = "preferred"
117+
118+
119+
class SearchOptions:
    """Settings consumed by the `Search` pipeline stage."""

    def __init__(
        self,
        query: str | BooleanExpression,
        limit: Optional[int] = None,
        retrieval_depth: Optional[int] = None,
        sort: Optional[Sequence[Ordering] | Ordering] = None,
        add_fields: Optional[Sequence[Selectable]] = None,
        select: Optional[Sequence[Selectable | str]] = None,
        offset: Optional[int] = None,
        query_enhancement: Optional[str | QueryEnhancement] = None,
        language_code: Optional[str] = None,
    ):
        """
        Builds a SearchOptions instance.

        Args:
            query (str | BooleanExpression): The search query used by the search stage to query and
                score documents. It may be given as an expression, which is used to score and filter
                the results (not every expression supported by Pipelines is supported in a Search
                query), or as a string in the Search DSL. A string is wrapped with `document_matches`.
            limit (Optional[int]): The maximum number of documents to return from the Search stage.
            retrieval_depth (Optional[int]): The maximum number of documents for the search stage to
                score. Documents are processed in the pre-sort order specified by the search index.
            sort (Optional[Sequence[Ordering] | Ordering]): How the input documents are sorted. A single
                `Ordering` is stored as a one-element list.
            add_fields (Optional[Sequence[Selectable]]): The fields to add to each document.
            select (Optional[Sequence[Selectable | str]]): The fields to keep or add to each document,
                given as `Selectable` instances or strings.
            offset (Optional[int]): The number of documents to skip.
            query_enhancement (Optional[str | QueryEnhancement]): The query expansion behavior used by
                full-text search expressions in this search stage. A string is lowercased and
                converted to the matching `QueryEnhancement` member.
            language_code (Optional[str]): The BCP-47 language code of text in the search query,
                such as "en-US" or "sr-Latn".
        """
        if not isinstance(query, str):
            self.query = query
        else:
            # Imported here rather than at module level, as in the original code.
            from google.cloud.firestore_v1.pipeline_expressions import document_matches

            self.query = document_matches(query)

        # A lone Ordering is normalized to a list so `sort` is always a sequence or None.
        if isinstance(sort, Ordering):
            sort = [sort]

        # String spellings are lowercased before the enum lookup.
        if isinstance(query_enhancement, str):
            query_enhancement = QueryEnhancement(query_enhancement.lower())

        self.limit = limit
        self.retrieval_depth = retrieval_depth
        self.sort = sort
        self.add_fields = add_fields
        self.select = select
        self.offset = offset
        self.query_enhancement = query_enhancement
        self.language_code = language_code

    def __repr__(self):
        """Renders the query plus every option that is not None, in declaration order."""
        # These attributes are rendered with plain formatting ...
        plain_attrs = (
            "limit",
            "retrieval_depth",
            "sort",
            "add_fields",
            "select",
            "offset",
        )
        # ... while these are rendered with repr().
        repr_attrs = ("query_enhancement", "language_code")

        parts = [f"query={self.query!r}"]
        for attr in plain_attrs:
            value = getattr(self, attr)
            if value is not None:
                parts.append(f"{attr}={value}")
        for attr in repr_attrs:
            value = getattr(self, attr)
            if value is not None:
                parts.append(f"{attr}={value!r}")
        return f"{self.__class__.__name__}({', '.join(parts)})"
189+
190+
112191
class UnnestOptions:
113192
"""Options for configuring the `Unnest` pipeline stage.
114193
@@ -423,6 +502,41 @@ def _pb_args(self):
423502
]
424503

425504

505+
class Search(Stage):
    """Search stage.

    Wraps a `SearchOptions` instance and serializes it as the named options
    of a stage called "search". The stage has no positional arguments.
    """

    def __init__(self, options: SearchOptions):
        """
        Args:
            options: The `SearchOptions` instance configuring this stage.
        """
        super().__init__("search")
        self.options = options

    def _pb_args(self) -> list[Value]:
        """Returns an empty list; all configuration is sent via `_pb_options`."""
        return []

    def _pb_options(self) -> dict[str, Value]:
        """Builds the stage's options map.

        Only options that are not None are included.

        Returns:
            A dict mapping option names to their `Value` encodings.
        """
        # Imported locally, matching the original code's function-scope imports
        # from this module.
        from google.cloud.firestore_v1.pipeline_expressions import Field, Selectable

        opts = self.options
        pb_options: dict[str, Value] = {}
        if opts.query is not None:
            pb_options["query"] = opts.query._to_pb()
        if opts.limit is not None:
            pb_options["limit"] = Value(integer_value=opts.limit)
        if opts.retrieval_depth is not None:
            pb_options["retrieval_depth"] = Value(integer_value=opts.retrieval_depth)
        if opts.sort is not None:
            pb_options["sort"] = Value(
                array_value={"values": [ordering._to_pb() for ordering in opts.sort]}
            )
        if opts.add_fields is not None:
            pb_options["add_fields"] = Selectable._to_value(opts.add_fields)
        if opts.select is not None:
            # `SearchOptions.select` accepts plain field-name strings as well as
            # Selectables. Turn strings into Field references so every entry is
            # a Selectable before it reaches the encoder used for add_fields.
            selections = [
                Field.of(entry) if isinstance(entry, str) else entry
                for entry in opts.select
            ]
            pb_options["select"] = Selectable._to_value(selections)
        if opts.offset is not None:
            pb_options["offset"] = Value(integer_value=opts.offset)
        if opts.query_enhancement is not None:
            pb_options["query_enhancement"] = Value(
                string_value=opts.query_enhancement.value
            )
        if opts.language_code is not None:
            pb_options["language_code"] = Value(string_value=opts.language_code)
        return pb_options
538+
539+
426540
class Select(Stage):
427541
"""Selects or creates a set of fields."""
428542

packages/google-cloud-firestore/tests/system/pipeline_e2e/data.yaml

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,4 +144,14 @@ data:
144144
doc_with_nan:
145145
value: "NaN"
146146
doc_with_null:
147-
value: null
147+
value: null
148+
geopoints:
149+
loc1:
150+
name: SF
151+
location: GEOPOINT(37.7749,-122.4194)
152+
loc2:
153+
name: LA
154+
location: GEOPOINT(34.0522,-118.2437)
155+
loc3:
156+
name: NY
157+
location: GEOPOINT(40.7128,-74.0060)
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
tests:
2+
- description: search_stage_basic
3+
pipeline:
4+
- Collection: books
5+
- Search:
6+
- SearchOptions:
7+
query: "technology"
8+
limit: 2
9+
assert_proto:
10+
pipeline:
11+
stages:
12+
- args:
13+
- referenceValue: /books
14+
name: collection
15+
- args: []
16+
name: search
17+
options:
18+
limit:
19+
integerValue: '2'
20+
query:
21+
functionValue:
22+
args:
23+
- stringValue: "technology"
24+
name: document_matches

packages/google-cloud-firestore/tests/system/test_pipeline_acceptance.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from google.cloud.firestore_v1 import pipeline_expressions as expr
3434
from google.cloud.firestore_v1 import pipeline_stages as stages
3535
from google.cloud.firestore_v1.vector import Vector
36+
from google.cloud.firestore_v1 import GeoPoint
3637

3738
FIRESTORE_PROJECT = os.environ.get("GCLOUD_PROJECT")
3839

@@ -343,12 +344,15 @@ def _parse_yaml_types(data):
343344
else:
344345
return [_parse_yaml_types(value) for value in data]
345346
# detect timestamps
346-
if isinstance(data, str) and ":" in data:
347+
if isinstance(data, str) and ":" in data and not data.startswith("GEOPOINT("):
347348
try:
348349
parsed_datetime = datetime.datetime.fromisoformat(data)
349350
return parsed_datetime
350351
except ValueError:
351352
pass
353+
if isinstance(data, str) and data.startswith("GEOPOINT("):
354+
parts = data[9:-1].split(",")
355+
return GeoPoint(float(parts[0]), float(parts[1]))
352356
if data == "NaN":
353357
return float("NaN")
354358
return data

packages/google-cloud-firestore/tests/unit/v1/test_pipeline_expressions.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,34 @@ def test_equal(self):
790790
infix_instance = arg1.equal(arg2)
791791
assert infix_instance == instance
792792

793+
def test_between(self):
    """`between` yields the same expression from the static and infix forms."""
    value = self._make_arg("Left")
    lower = self._make_arg("Lower")
    upper = self._make_arg("Upper")

    instance = Expression.between(value, lower, upper)

    assert instance.name == "between"
    assert instance.params == [value, lower, upper]
    assert repr(instance) == "Left.between(Lower, Upper)"
    # Calling through the instance must produce an equal expression.
    assert value.between(lower, upper) == instance
803+
804+
def test_geo_distance(self):
    """`geo_distance` yields the same expression from the static and infix forms."""
    left = self._make_arg("Left")
    right = self._make_arg("Right")

    instance = Expression.geo_distance(left, right)

    assert instance.name == "geo_distance"
    assert instance.params == [left, right]
    assert repr(instance) == "Left.geo_distance(Right)"
    # Calling through the instance must produce an equal expression.
    assert left.geo_distance(right) == instance
813+
814+
def test_document_matches(self):
    """`document_matches` wraps its query in an expression of the same name."""
    query = self._make_arg("Query")

    instance = expr.document_matches(query)

    assert instance.name == "document_matches"
    assert instance.params == [query]
    assert repr(instance) == "document_matches(Query)"
820+
793821
def test_greater_than_or_equal(self):
794822
arg1 = self._make_arg("Left")
795823
arg2 = self._make_arg("Right")
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from google.cloud.firestore_v1.pipeline_stages import SearchOptions, Search, QueryEnhancement
2+
from google.cloud.firestore_v1.pipeline_expressions import Field, Ordering, document_matches
3+
4+
def test_search_options():
    """SearchOptions stores its arguments and Search encodes them as options."""
    score_desc = Ordering("score", Ordering.Direction.DESCENDING)
    opts = SearchOptions(
        query="test query",
        limit=10,
        retrieval_depth=2,
        sort=score_desc,
        add_fields=[Field("extra")],
        select=[Field("name")],
        offset=5,
        query_enhancement="disabled",
        language_code="en",
    )

    # Attribute storage, including normalization of sort and query_enhancement.
    assert opts.limit == 10
    assert opts.retrieval_depth == 2
    assert len(opts.sort) == 1
    assert opts.offset == 5
    assert opts.query_enhancement == QueryEnhancement.DISABLED
    assert opts.language_code == "en"

    # Check proto generation
    encoded = Search(opts)._pb_options()

    assert encoded["limit"].integer_value == 10
    assert encoded["retrieval_depth"].integer_value == 2
    assert len(encoded["sort"].array_value.values) == 1
    assert encoded["offset"].integer_value == 5
    assert encoded["query_enhancement"].string_value == "disabled"
    assert encoded["language_code"].string_value == "en"
33+
34+
def test_search_options_bool_expr():
    """A query given as a boolean expression is emitted under the "query" option."""
    query_expr = document_matches("query string")
    search_stage = Search(SearchOptions(query=query_expr))
    encoded = search_stage._pb_options()
    assert "query" in encoded
39+
40+
def test_between():
    """`Field.between` builds a "between" expression with three parameters."""
    age_range = Field("age").between(18, 65)
    assert age_range.name == "between"
    assert len(age_range.params) == 3
44+
45+
def test_geo_distance():
    """`Field.geo_distance` builds a "geo_distance" expression with two parameters."""
    distance = Field("location").geo_distance("other")
    assert distance.name == "geo_distance"
    assert len(distance.params) == 2
49+
50+
def test_document_matches():
    """`document_matches` builds a "document_matches" expression with one parameter."""
    match_expr = document_matches("search query")
    assert match_expr.name == "document_matches"
    assert len(match_expr.params) == 1

0 commit comments

Comments
 (0)