Skip to content

Commit da46650

Browse files
committed
Merge branch 'main' into firestore_pipelines_search
2 parents 5fb5a44 + 281eaae commit da46650

File tree

5 files changed

+210
-6
lines changed

5 files changed

+210
-6
lines changed

packages/google-cloud-firestore/google/cloud/firestore_v1/base_pipeline.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,62 @@ def distinct(self, *fields: str | Selectable) -> "_BasePipeline":
696696
"""
697697
return self._append(stages.Distinct(*fields))
698698

699+
def delete(self) -> "_BasePipeline":
700+
"""
701+
Deletes the documents from the current pipeline stage.
702+
703+
Example:
704+
>>> from google.cloud.firestore_v1.pipeline_expressions import Field
705+
>>> pipeline = client.pipeline().collection("logs")
706+
>>> # Delete all documents in the "logs" collection where "status" is "archived"
707+
>>> pipeline = pipeline.where(Field.of("status").equal("archived")).delete()
708+
>>> pipeline.execute()
709+
710+
Returns:
711+
A new Pipeline object with this stage appended to the stage list
712+
"""
713+
return self._append(stages.Delete())
714+
715+
def update(self, *transformed_fields: "Selectable") -> "_BasePipeline":
716+
"""
717+
Performs an update operation using documents from previous stages.
718+
719+
If called without `transformed_fields`, this method updates the documents in
720+
place based on the data flowing through the pipeline.
721+
722+
To update specific fields with new values, provide `Selectable` expressions that define the
723+
transformations to apply.
724+
725+
Example 1: Update a collection's schema by adding a new field and removing an old one.
726+
>>> from google.cloud.firestore_v1.pipeline_expressions import Constant
727+
>>> pipeline = client.pipeline().collection("books")
728+
>>> pipeline = pipeline.add_fields(Constant.of("Fiction").as_("genre"))
729+
>>> pipeline = pipeline.remove_fields("old_genre").update()
730+
>>> pipeline.execute()
731+
732+
Example 2: Update documents in place with data from literals.
733+
>>> pipeline = client.pipeline().literals(
734+
... {"__name__": client.collection("books").document("book1"), "status": "Updated"}
735+
... ).update()
736+
>>> pipeline.execute()
737+
738+
Example 3: Update documents from previous stages with specified transformations.
739+
>>> from google.cloud.firestore_v1.pipeline_expressions import Field, Constant
740+
>>> pipeline = client.pipeline().collection("books")
741+
>>> # Update the "status" field to "Discounted" for all books where price > 50
742+
>>> pipeline = pipeline.where(Field.of("price").greater_than(50))
743+
>>> pipeline = pipeline.update(Constant.of("Discounted").as_("status"))
744+
>>> pipeline.execute()
745+
746+
Args:
747+
*transformed_fields: Optional. The transformations to apply. If not provided,
748+
the update is performed in place based on the data flowing through the pipeline.
749+
750+
Returns:
751+
A new Pipeline object with this stage appended to the stage list
752+
"""
753+
return self._append(stages.Update(*transformed_fields))
754+
699755
def define(self, *aliased_expressions: AliasedExpression) -> "_BasePipeline":
700756
"""
701757
Binds one or more expressions to Variables that can be accessed in subsequent stages

packages/google-cloud-firestore/google/cloud/firestore_v1/pipeline_stages.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,27 @@ def _pb_args(self):
617617
return [self.condition._to_pb()]
618618

619619

620+
class Delete(Stage):
621+
"""Deletes documents matching the pipeline criteria."""
622+
623+
def __init__(self):
624+
super().__init__("delete")
625+
626+
def _pb_args(self) -> list[Value]:
627+
return []
628+
629+
630+
class Update(Stage):
631+
"""Updates documents with transformed fields."""
632+
633+
def __init__(self, *transformed_fields: Selectable):
634+
super().__init__("update")
635+
self.transformed_fields = list(transformed_fields)
636+
637+
def _pb_args(self) -> list[Value]:
638+
return [Selectable._to_value(self.transformed_fields)]
639+
640+
620641
class Define(Stage):
621642
"""Binds one or more expressions to variables."""
622643

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
data:
2+
dml_delete_coll:
3+
doc1: { score: 10 }
4+
doc2: { score: 60 }
5+
dml_update_coll:
6+
doc1: { status: "pending", score: 50 }
7+
8+
tests:
9+
- description: "Basic DML delete"
10+
pipeline:
11+
- Collection: dml_delete_coll
12+
- Where:
13+
FunctionExpression.less_than:
14+
- Field: score
15+
- Constant: 50
16+
- Delete:
17+
assert_end_state:
18+
dml_delete_coll/doc1: null
19+
dml_delete_coll/doc2: { score: 60 }
20+
assert_proto:
21+
pipeline:
22+
stages:
23+
- args:
24+
- referenceValue: /dml_delete_coll
25+
name: collection
26+
- args:
27+
- functionValue:
28+
args:
29+
- fieldReferenceValue: score
30+
- integerValue: '50'
31+
name: less_than
32+
name: where
33+
- name: delete
34+
35+
- description: "Basic DML update"
36+
pipeline:
37+
- Collection: dml_update_coll
38+
- Update:
39+
- AliasedExpression:
40+
- Constant: "active"
41+
- "status"
42+
assert_end_state:
43+
dml_update_coll/doc1: { status: "active", score: 50 }
44+
assert_proto:
45+
pipeline:
46+
stages:
47+
- args:
48+
- referenceValue: /dml_update_coll
49+
name: collection
50+
- args:
51+
- mapValue:
52+
fields:
53+
status:
54+
stringValue: active
55+
name: update

packages/google-cloud-firestore/tests/system/test_pipeline_acceptance.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def test_pipeline_expected_errors(test_dict, client):
119119
if "assert_results" in t
120120
or "assert_count" in t
121121
or "assert_results_approximate" in t
122+
or "assert_end_state" in t
122123
],
123124
ids=id_format,
124125
)
@@ -131,6 +132,7 @@ def test_pipeline_results(test_dict, client):
131132
test_dict.get("assert_results_approximate", None)
132133
)
133134
expected_count = test_dict.get("assert_count", None)
135+
expected_end_state = _parse_yaml_types(test_dict.get("assert_end_state", {}))
134136
pipeline = parse_pipeline(client, test_dict["pipeline"])
135137
# check if server responds as expected
136138
got_results = [snapshot.data() for snapshot in pipeline.stream()]
@@ -146,6 +148,19 @@ def test_pipeline_results(test_dict, client):
146148
)
147149
if expected_count is not None:
148150
assert len(got_results) == expected_count
151+
if expected_end_state:
152+
for doc_path, expected_content in expected_end_state.items():
153+
doc_ref = client.document(doc_path)
154+
snapshot = doc_ref.get()
155+
if expected_content is None:
156+
assert not snapshot.exists, (
157+
f"Expected {doc_path} to be absent, but it exists"
158+
)
159+
else:
160+
assert snapshot.exists, (
161+
f"Expected {doc_path} to exist, but it was absent"
162+
)
163+
assert snapshot.to_dict() == expected_content
149164

150165

151166
@pytest.mark.parametrize(
@@ -176,6 +191,7 @@ async def test_pipeline_expected_errors_async(test_dict, async_client):
176191
if "assert_results" in t
177192
or "assert_count" in t
178193
or "assert_results_approximate" in t
194+
or "assert_end_state" in t
179195
],
180196
ids=id_format,
181197
)
@@ -189,6 +205,7 @@ async def test_pipeline_results_async(test_dict, async_client):
189205
test_dict.get("assert_results_approximate", None)
190206
)
191207
expected_count = test_dict.get("assert_count", None)
208+
expected_end_state = _parse_yaml_types(test_dict.get("assert_end_state", {}))
192209
pipeline = parse_pipeline(async_client, test_dict["pipeline"])
193210
# check if server responds as expected
194211
got_results = [snapshot.data() async for snapshot in pipeline.stream()]
@@ -204,6 +221,19 @@ async def test_pipeline_results_async(test_dict, async_client):
204221
)
205222
if expected_count is not None:
206223
assert len(got_results) == expected_count
224+
if expected_end_state:
225+
for doc_path, expected_content in expected_end_state.items():
226+
doc_ref = async_client.document(doc_path)
227+
snapshot = await doc_ref.get()
228+
if expected_content is None:
229+
assert not snapshot.exists, (
230+
f"Expected {doc_path} to be absent, but it exists"
231+
)
232+
else:
233+
assert snapshot.exists, (
234+
f"Expected {doc_path} to exist, but it was absent"
235+
)
236+
assert snapshot.to_dict() == expected_content
207237

208238

209239
#################################################################################
@@ -223,7 +253,12 @@ def parse_pipeline(client, pipeline: list[dict[str, Any], str]):
223253
# find arguments if given
224254
if isinstance(stage, dict):
225255
stage_yaml_args = stage[stage_name]
226-
stage_obj = _apply_yaml_args_to_callable(stage_cls, client, stage_yaml_args)
256+
if stage_yaml_args is None:
257+
stage_obj = stage_cls()
258+
else:
259+
stage_obj = _apply_yaml_args_to_callable(
260+
stage_cls, client, stage_yaml_args
261+
)
227262
else:
228263
# yaml has no arguments
229264
stage_obj = stage_cls()
@@ -291,20 +326,21 @@ def _apply_yaml_args_to_callable(callable_obj, client, yaml_args):
291326
Helper to instantiate a class with yaml arguments. The arguments will be applied
292327
as positional or keyword arguments, based on type
293328
"""
294-
if isinstance(yaml_args, dict):
295-
return callable_obj(**_parse_expressions(client, yaml_args))
329+
parsed = _parse_expressions(client, yaml_args)
330+
if isinstance(yaml_args, dict) and isinstance(parsed, dict):
331+
return callable_obj(**parsed)
296332
elif isinstance(yaml_args, list) and not (
297333
callable_obj == expr.Constant
298334
or callable_obj == Vector
299335
or callable_obj == expr.Array
300336
):
301337
# yaml has an array of arguments. Treat as args
302-
return callable_obj(*_parse_expressions(client, yaml_args))
303-
elif yaml_args is None:
338+
return callable_obj(*parsed)
339+
elif yaml_args is None and callable_obj != expr.Constant:
304340
return callable_obj()
305341
else:
306342
# yaml has a single argument
307-
return callable_obj(_parse_expressions(client, yaml_args))
343+
return callable_obj(parsed)
308344

309345

310346
def _is_expr_string(yaml_str):

packages/google-cloud-firestore/tests/unit/v1/test_pipeline_stages.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,3 +1031,39 @@ def test_to_pb(self):
10311031
assert got_fn.args[0].field_reference_value == "city"
10321032
assert got_fn.args[1].string_value == "SF"
10331033
assert len(result.options) == 0
1034+
1035+
1036+
class TestDelete:
1037+
def _make_one(self):
1038+
return stages.Delete()
1039+
1040+
def test_to_pb(self):
1041+
instance = self._make_one()
1042+
result = instance._to_pb()
1043+
assert result.name == "delete"
1044+
assert len(result.args) == 0
1045+
assert len(result.options) == 0
1046+
1047+
1048+
class TestUpdate:
1049+
def _make_one(self, *args):
1050+
return stages.Update(*args)
1051+
1052+
def test_to_pb_empty(self):
1053+
instance = self._make_one()
1054+
result = instance._to_pb()
1055+
assert result.name == "update"
1056+
assert len(result.args) == 1
1057+
assert result.args[0].map_value.fields == {}
1058+
assert len(result.options) == 0
1059+
1060+
def test_to_pb_with_fields(self):
1061+
instance = self._make_one(
1062+
Field.of("score").add(10).as_("score"), Constant.of("active").as_("status")
1063+
)
1064+
result = instance._to_pb()
1065+
assert result.name == "update"
1066+
assert len(result.args) == 1
1067+
assert "score" in result.args[0].map_value.fields
1068+
assert "status" in result.args[0].map_value.fields
1069+
assert len(result.options) == 0

0 commit comments

Comments
 (0)