Skip to content
This repository was archived by the owner on Mar 2, 2026. It is now read-only.

Commit 002950d

Browse files
committed
added create_from to pipeline source
1 parent 373d59c commit 002950d

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

google/cloud/firestore_v1/pipeline_source.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
from google.cloud.firestore_v1.client import Client
2323
from google.cloud.firestore_v1.async_client import AsyncClient
2424
from google.cloud.firestore_v1.base_document import BaseDocumentReference
25+
from google.cloud.firestore_v1.base_query import BaseQuery
26+
from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery
27+
from google.cloud.firestore_v1.base_collection import BaseCollectionReference
2528

2629

2730
PipelineType = TypeVar("PipelineType", bound=_BasePipeline)
@@ -43,6 +46,21 @@ def __init__(self, client: Client | AsyncClient):
4346
def _create_pipeline(self, source_stage):
4447
return self.client._pipeline_cls._create_with_stages(self.client, source_stage)
4548

49+
def create_from(self, query: "BaseQuery" | "BaseAggregationQuery" | "BaseCollectionReference") -> PipelineType:
50+
"""
51+
Create a pipeline from an existing query
52+
53+
Queries containing a `cursor` or `limit_to_last` are not currently supported
54+
55+
Args:
56+
query: the query to build the pipeline off of
57+
Raises:
58+
- NotImplementedError: raised if the query contains a `cursor` or `limit_to_last`
59+
Returns:
60+
a new pipeline instance representing the query
61+
"""
62+
return query._build_pipeline(self)
63+
4664
def collection(self, path: str | tuple[str]) -> PipelineType:
4765
"""
4866
Creates a new Pipeline that operates on a specified Firestore collection.

tests/unit/v1/test_pipeline_source.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License
14+
import mock
1415

1516
from google.cloud.firestore_v1.pipeline_source import PipelineSource
1617
from google.cloud.firestore_v1.pipeline import Pipeline
@@ -19,6 +20,8 @@
1920
from google.cloud.firestore_v1.async_client import AsyncClient
2021
from google.cloud.firestore_v1 import pipeline_stages as stages
2122
from google.cloud.firestore_v1.base_document import BaseDocumentReference
23+
from google.cloud.firestore_v1.query import Query
24+
from google.cloud.firestore_v1.async_query import AsyncQuery
2225

2326

2427
class TestPipelineSource:
@@ -27,6 +30,9 @@ class TestPipelineSource:
2730
def _make_client(self):
2831
return Client()
2932

33+
def _make_query(self):
34+
return Query(mock.Mock())
35+
3036
def test_make_from_client(self):
3137
instance = self._make_client().pipeline()
3238
assert isinstance(instance, PipelineSource)
@@ -36,6 +42,23 @@ def test_create_pipeline(self):
3642
ppl = instance._create_pipeline(None)
3743
assert isinstance(ppl, self._expected_pipeline_type)
3844

45+
def test_create_from_mock(self):
46+
mock_query = mock.Mock()
47+
expected = object()
48+
mock_query._build_pipeline.return_value = expected
49+
instance = self._make_client().pipeline()
50+
got = instance.create_from(mock_query)
51+
assert got == expected
52+
assert mock_query._build_pipeline.call_count == 1
53+
assert mock_query._build_pipeline.call_args_list[0][0][0] == instance
54+
55+
def test_create_from_query(self):
56+
query = self._make_query()
57+
instance = self._make_client().pipeline()
58+
ppl = instance.create_from(query)
59+
assert isinstance(ppl, self._expected_pipeline_type)
60+
assert len(ppl.stages) == 1
61+
3962
def test_collection(self):
4063
instance = self._make_client().pipeline()
4164
ppl = instance.collection("path")
@@ -98,3 +121,6 @@ class TestPipelineSourceWithAsyncClient(TestPipelineSource):
98121

99122
def _make_client(self):
100123
return AsyncClient()
124+
125+
def _make_query(self):
126+
return AsyncQuery(mock.Mock())

0 commit comments

Comments
 (0)