Skip to content

Commit ae55a1b

Browse files
feat(firestore): pipeline subqueries (#16470)
Add support for Pipeline subqueries, allowing users to perform complex data transformations by embedding a full pipeline inside another. New Stages: - `define` - `subcollection` New Expressions: - `current_document` - `variable` - `get_field` New classes: - SubPipeline, which represents a pipeline without an associated client, which cannot be executed --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 3f07f69 commit ae55a1b

File tree

18 files changed

+703
-51
lines changed

18 files changed

+703
-51
lines changed

packages/google-cloud-firestore/google/cloud/firestore_v1/async_pipeline.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ class AsyncPipeline(_BasePipeline):
6666
subject to potential breaking changes in future releases
6767
"""
6868

69+
_client: AsyncClient
70+
6971
def __init__(self, client: AsyncClient, *stages: stages.Stage):
7072
"""
7173
Initializes an asynchronous Pipeline.
@@ -102,6 +104,9 @@ async def execute(
102104
explain_metrics will be available on the returned list.
103105
additional_options (Optional[dict[str, Value | Constant]]): Additional options to pass to the query.
104106
These options will take precedence over method argument if there is a conflict (e.g. explain_options)
107+
108+
Raises:
109+
google.api_core.exceptions.GoogleAPIError: If there is a backend error.
105110
"""
106111
kwargs = {k: v for k, v in locals().items() if k != "self"}
107112
stream = AsyncPipelineStream(PipelineResult, self, **kwargs)
@@ -134,6 +139,9 @@ def stream(
134139
explain_metrics will be available on the returned generator.
135140
additional_options (Optional[dict[str, Value | Constant]]): Additional options to pass to the query.
136141
These options will take precedence over method argument if there is a conflict (e.g. explain_options)
142+
143+
Raises:
144+
google.api_core.exceptions.GoogleAPIError: If there is a backend error.
137145
"""
138146
kwargs = {k: v for k, v in locals().items() if k != "self"}
139147
return AsyncPipelineStream(PipelineResult, self, **kwargs)

packages/google-cloud-firestore/google/cloud/firestore_v1/base_pipeline.py

Lines changed: 112 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import TYPE_CHECKING, Sequence
17+
from typing import TYPE_CHECKING, Sequence, TypeVar, Type
18+
1819

1920
from google.cloud.firestore_v1 import pipeline_stages as stages
2021
from google.cloud.firestore_v1.base_vector_query import DistanceMeasure
@@ -25,6 +26,8 @@
2526
Expression,
2627
Field,
2728
Selectable,
29+
FunctionExpression,
30+
_PipelineValueExpression,
2831
)
2932
from google.cloud.firestore_v1.types.pipeline import (
3033
StructuredPipeline as StructuredPipeline_pb,
@@ -35,6 +38,8 @@
3538
from google.cloud.firestore_v1.async_client import AsyncClient
3639
from google.cloud.firestore_v1.client import Client
3740

41+
_T = TypeVar("_T", bound="_BasePipeline")
42+
3843

3944
class _BasePipeline:
4045
"""
@@ -44,7 +49,7 @@ class _BasePipeline:
4449
Use `client.pipeline()` to create pipeline instances.
4550
"""
4651

47-
def __init__(self, client: Client | AsyncClient):
52+
def __init__(self, client: Client | AsyncClient | None):
4853
"""
4954
Initializes a new pipeline.
5055
@@ -59,8 +64,8 @@ def __init__(self, client: Client | AsyncClient):
5964

6065
@classmethod
6166
def _create_with_stages(
62-
cls, client: Client | AsyncClient, *stages
63-
) -> _BasePipeline:
67+
cls: Type[_T], client: Client | AsyncClient | None, *stages
68+
) -> _T:
6469
"""
6570
Initializes a new pipeline with the given stages.
6671
@@ -90,6 +95,51 @@ def _to_pb(self, **options) -> StructuredPipeline_pb:
9095
options=options,
9196
)
9297

98+
def to_array_expression(self) -> Expression:
99+
"""
100+
Converts this Pipeline into an expression that evaluates to an array of results.
101+
Used for embedding 1:N subqueries into stages like `addFields`.
102+
103+
Example:
104+
>>> # Get a list of all reviewer names for each book
105+
>>> db.pipeline().collection("books").define(Field.of("id").as_("book_id")).add_fields(
106+
... db.pipeline().collection("reviews")
107+
... .where(Field.of("book_id").equal(Variable("book_id")))
108+
... .select(Field.of("reviewer").as_("name"))
109+
... .to_array_expression().as_("reviewers")
110+
... )
111+
112+
Returns:
113+
An :class:`Expression` representing the execution of this pipeline.
114+
"""
115+
return FunctionExpression("array", [_PipelineValueExpression(self)])
116+
117+
def to_scalar_expression(self) -> Expression:
118+
"""
119+
Converts this Pipeline into an expression that evaluates to a single scalar result.
120+
Used for 1:1 lookups or Aggregations when the subquery is expected to return a single value or object.
121+
122+
**Result Unwrapping:**
123+
For simpler access, scalar subqueries producing a single field automatically unwrap that value to the
124+
top level, ignoring the inner alias. If the subquery returns multiple fields, they are preserved as a map.
125+
126+
Example:
127+
>>> # Calculate average rating for each restaurant using a subquery
128+
>>> db.pipeline().collection("restaurants").define(Field.of("id").as_("rid")).add_fields(
129+
... db.pipeline().collection("reviews")
130+
... .where(Field.of("restaurant_id").equal(Variable("rid")))
131+
... .aggregate(AggregateFunction.average("rating").as_("value"))
132+
... .to_scalar_expression().as_("average_rating")
133+
... )
134+
135+
**Runtime Validation:**
136+
The runtime will validate that the result set contains exactly one item. It returns an error if the result has more than one item, and evaluates to `null` if the pipeline has zero results.
137+
138+
Returns:
139+
An :class:`Expression` representing the execution of this pipeline.
140+
"""
141+
return FunctionExpression("scalar", [_PipelineValueExpression(self)])
142+
93143
def _append(self, new_stage):
94144
"""
95145
Create a new Pipeline object with a new stage appended
@@ -391,9 +441,17 @@ def union(self, other: "_BasePipeline") -> "_BasePipeline":
391441
Args:
392442
other: The other `Pipeline` whose results will be unioned with this one.
393443
444+
Raises:
445+
ValueError: If the `other` pipeline is a relative pipeline (e.g. created without a client).
446+
394447
Returns:
395448
A new Pipeline object with this stage appended to the stage list
396449
"""
450+
if other._client is None:
451+
raise ValueError(
452+
"Union only supports combining root pipelines, doesn't support relative scope Pipeline "
453+
"like relative subcollection pipeline"
454+
)
397455
return self._append(stages.Union(other))
398456

399457
def unnest(
@@ -610,3 +668,53 @@ def distinct(self, *fields: str | Selectable) -> "_BasePipeline":
610668
A new Pipeline object with this stage appended to the stage list
611669
"""
612670
return self._append(stages.Distinct(*fields))
671+
672+
def define(self, *aliased_expressions: AliasedExpression) -> "_BasePipeline":
673+
"""
674+
Binds one or more expressions to Variables that can be accessed in subsequent stages
675+
or inner subqueries using `Variable`.
676+
677+
Each Variable is defined using an :class:`AliasedExpression`, which pairs an expression with
678+
a name (alias).
679+
680+
Example:
681+
>>> db.pipeline().collection("products").define(
682+
... Field.of("price").multiply(0.9).as_("discountedPrice"),
683+
... Field.of("stock").add(10).as_("newStock")
684+
... ).where(
685+
... Variable("discountedPrice").less_than(100)
686+
... ).select(Field.of("name"), Variable("newStock"))
687+
688+
Args:
689+
*aliased_expressions: One or more :class:`AliasedExpression` defining the Variable names and values.
690+
691+
Returns:
692+
A new Pipeline object with this stage appended to the stage list.
693+
"""
694+
return self._append(stages.Define(*aliased_expressions))
695+
696+
697+
class SubPipeline(_BasePipeline):
698+
"""
699+
A pipeline scoped to a subcollection, created without a database client.
700+
Cannot be executed directly; it must be used as a subquery within another pipeline.
701+
"""
702+
703+
_EXECUTE_ERROR_MSG = (
704+
"This pipeline was created without a database (e.g., as a subcollection pipeline) and "
705+
"cannot be executed directly. It can only be used as part of another pipeline."
706+
)
707+
708+
def execute(self, *args, **kwargs):
709+
"""
710+
Raises:
711+
RuntimeError: Always, as a subcollection pipeline cannot be executed directly.
712+
"""
713+
raise RuntimeError(self._EXECUTE_ERROR_MSG)
714+
715+
def stream(self, *args, **kwargs):
716+
"""
717+
Raises:
718+
RuntimeError: Always, as a subcollection pipeline cannot be streamed directly.
719+
"""
720+
raise RuntimeError(self._EXECUTE_ERROR_MSG)

packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ class Pipeline(_BasePipeline):
6363
subject to potential breaking changes in future releases.
6464
"""
6565

66+
_client: Client
67+
6668
def __init__(self, client: Client, *stages: stages.Stage):
6769
"""
6870
Initializes a Pipeline.
@@ -99,6 +101,9 @@ def execute(
99101
explain_metrics will be available on the returned list.
100102
additional_options (Optional[dict[str, Value | Constant]]): Additional options to pass to the query.
101103
These options will take precedence over method argument if there is a conflict (e.g. explain_options)
104+
105+
Raises:
106+
google.api_core.exceptions.GoogleAPIError: If there is a backend error.
102107
"""
103108
kwargs = {k: v for k, v in locals().items() if k != "self"}
104109
stream = PipelineStream(PipelineResult, self, **kwargs)
@@ -131,6 +136,9 @@ def stream(
131136
explain_metrics will be available on the returned generator.
132137
additional_options (Optional[dict[str, Value | Constant]]): Additional options to pass to the query.
133138
These options will take precedence over method argument if there is a conflict (e.g. explain_options)
139+
140+
Raises:
141+
google.api_core.exceptions.GoogleAPIError: If there is a backend error.
134142
"""
135143
kwargs = {k: v for k, v in locals().items() if k != "self"}
136144
return PipelineStream(PipelineResult, self, **kwargs)

packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_expressions.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,19 @@
2323
from abc import ABC, abstractmethod
2424
from enum import Enum
2525
from typing import (
26+
TYPE_CHECKING,
2627
Any,
2728
Generic,
2829
Sequence,
2930
TypeVar,
3031
)
3132

33+
if TYPE_CHECKING:
34+
from google.cloud.firestore_v1.base_pipeline import _BasePipeline
35+
3236
from google.cloud.firestore_v1._helpers import GeoPoint, decode_value, encode_value
3337
from google.cloud.firestore_v1.types.document import Value
38+
from google.cloud.firestore_v1.types.document import Pipeline as Pipeline_pb
3439
from google.cloud.firestore_v1.types.query import StructuredQuery as Query_pb
3540
from google.cloud.firestore_v1.vector import Vector
3641

@@ -258,6 +263,26 @@ def __get__(self, instance, owner):
258263
else:
259264
return self.instance_func.__get__(instance, owner)
260265

266+
@expose_as_static
267+
def get_field(self, key: Expression | str) -> "Expression":
268+
"""Accesses a field/property of the expression that evaluates to a Map or Document.
269+
270+
Example:
271+
>>> # Access the 'city' field from the 'address' map field.
272+
>>> Field.of("address").get_field("city")
273+
>>> # Create a map and access a field from it.
274+
>>> Map({"foo": "bar"}).get_field("foo")
275+
276+
Args:
277+
key: The key of the field to access.
278+
279+
Returns:
280+
A new `Expression` representing the value of the field.
281+
"""
282+
return FunctionExpression(
283+
"get_field", [self, self._cast_to_expr_or_convert_to_constant(key)]
284+
)
285+
261286
@expose_as_static
262287
def add(self, other: Expression | float) -> "Expression":
263288
"""Creates an expression that adds this expression to another expression or constant.
@@ -2709,6 +2734,17 @@ def _from_query_filter_pb(filter_pb, client):
27092734
raise TypeError(f"Unexpected filter type: {type(filter_pb)}")
27102735

27112736

2737+
class _PipelineValueExpression(Expression):
2738+
"""Internal wrapper to represent a pipeline as an expression."""
2739+
2740+
def __init__(self, pipeline: "_BasePipeline"):
2741+
self.pipeline = pipeline
2742+
2743+
def _to_pb(self) -> Value:
2744+
pipeline_pb = Pipeline_pb(stages=[s._to_pb() for s in self.pipeline.stages])
2745+
return Value(pipeline_value=pipeline_pb)
2746+
2747+
27122748
class Array(FunctionExpression):
27132749
"""
27142750
Creates an expression that creates a Firestore array value from an input list.
@@ -2889,3 +2925,42 @@ class Rand(FunctionExpression):
28892925

28902926
def __init__(self):
28912927
super().__init__("rand", [], use_infix_repr=False)
2928+
2929+
2930+
class Variable(Expression):
2931+
"""
2932+
Creates an expression that retrieves the value of a variable bound via `Pipeline.define`.
2933+
2934+
Example:
2935+
>>> # Define a variable "discountedPrice" and use it in a filter
2936+
>>> db.pipeline().collection("products").define(
2937+
... Field.of("price").multiply(0.9).as_("discountedPrice")
2938+
... ).where(Variable("discountedPrice").less_than(100))
2939+
2940+
Args:
2941+
name: The name of the variable to retrieve.
2942+
"""
2943+
2944+
def __init__(self, name: str):
2945+
self.name = name
2946+
2947+
def _to_pb(self) -> Value:
2948+
return Value(variable_reference_value=self.name)
2949+
2950+
2951+
class CurrentDocument(FunctionExpression):
2952+
"""
2953+
Creates an expression that represents the current document being processed.
2954+
2955+
This acts as a handle, allowing you to bind the entire document to a variable or pass the
2956+
document itself to a function or subquery.
2957+
2958+
Example:
2959+
>>> # Define the current document as a variable "doc"
2960+
>>> db.pipeline().collection("books").define(
2961+
... CurrentDocument().as_("doc")
2962+
... ).select(Variable("doc").get_field("title"))
2963+
"""
2964+
2965+
def __init__(self):
2966+
super().__init__("current_document", [])

packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_result.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@
4949
from google.cloud.firestore_v1.async_transaction import AsyncTransaction
5050
from google.cloud.firestore_v1.base_client import BaseClient
5151
from google.cloud.firestore_v1.base_document import BaseDocumentReference
52-
from google.cloud.firestore_v1.base_pipeline import _BasePipeline
5352
from google.cloud.firestore_v1.client import Client
5453
from google.cloud.firestore_v1.pipeline import Pipeline
5554
from google.cloud.firestore_v1.pipeline_expressions import Constant
@@ -190,7 +189,7 @@ def __init__(
190189
):
191190
# public
192191
self.transaction = transaction
193-
self.pipeline: _BasePipeline = pipeline
192+
self.pipeline: Pipeline | AsyncPipeline = pipeline
194193
self.execution_time: Timestamp | None = None
195194
# private
196195
self._client: Client | AsyncClient = pipeline._client

packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_source.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from google.cloud.firestore_v1 import pipeline_stages as stages
2525
from google.cloud.firestore_v1._helpers import DOCUMENT_PATH_DELIMITER
26-
from google.cloud.firestore_v1.base_pipeline import _BasePipeline
26+
from google.cloud.firestore_v1.base_pipeline import _BasePipeline, SubPipeline
2727

2828
if TYPE_CHECKING: # pragma: NO COVER
2929
from google.cloud.firestore_v1.async_client import AsyncClient
@@ -168,6 +168,32 @@ def literals(
168168
*documents: One or more documents to be returned by this stage. Each can be a `dict`
169169
of values of `Expression` or `CONSTANT_TYPE` types.
170170
Returns:
171-
A new Pipeline object with this stage appended to the stage list.
171+
A new pipeline instance targeting the specified literal documents
172172
"""
173173
return self._create_pipeline(stages.Literals(*documents))
174+
175+
@staticmethod
176+
def subcollection(path: str) -> SubPipeline:
177+
"""
178+
Initializes a pipeline scoped to a subcollection.
179+
180+
This method allows you to start a new pipeline that operates on a subcollection of the
181+
current document. It is intended to be used as a subquery.
182+
183+
**Note:** A pipeline created with `subcollection` cannot be executed directly using
184+
`execute()`. It must be used within a parent pipeline.
185+
186+
Example:
187+
>>> db.pipeline().collection("books").add_fields(
188+
... PipelineSource.subcollection("reviews")
189+
... .aggregate(AggregateFunction.average("rating").as_("avg_rating"))
190+
... .to_scalar_expression().as_("average_rating")
191+
... )
192+
193+
Args:
194+
path: The path of the subcollection.
195+
196+
Returns:
197+
A new pipeline instance targeting the specified subcollection
198+
"""
199+
return SubPipeline._create_with_stages(None, stages.Subcollection(path))

0 commit comments

Comments
 (0)