Skip to content

Commit 91a69c1

Browse files
committed
exposed options in raw_stage
1 parent 0d398dd commit 91a69c1

File tree

3 files changed

+36
-2
lines changed

3 files changed

+36
-2
lines changed

packages/google-cloud-firestore/google/cloud/firestore_v1/base_pipeline.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,12 @@ def unnest(
459459
"""
460460
return self._append(stages.Unnest(field, alias, options))
461461

462-
def raw_stage(self, name: str, *params: Expression) -> "_BasePipeline":
462+
def raw_stage(
463+
self,
464+
name: str,
465+
*params: Expression,
466+
options: dict[str, Expression | Value] | None = None,
467+
) -> "_BasePipeline":
463468
"""
464469
Adds a stage to the pipeline by specifying the stage name as an argument. This does not offer any
465470
type safety on the stage params and requires the caller to know the order (and optionally names)
@@ -477,11 +482,12 @@ def raw_stage(self, name: str, *params: Expression) -> "_BasePipeline":
477482
Args:
478483
name: The name of the stage.
479484
*params: A sequence of `Expression` objects representing the parameters for the stage.
485+
options: An optional dictionary of stage options.
480486
481487
Returns:
482488
A new Pipeline object with this stage appended to the stage list
483489
"""
484-
return self._append(stages.RawStage(name, *params))
490+
return self._append(stages.RawStage(name, *params, options=options or {}))
485491

486492
def offset(self, offset: int) -> "_BasePipeline":
487493
"""

packages/google-cloud-firestore/tests/unit/v1/test_async_pipeline.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,3 +448,17 @@ def test_async_pipeline_aggregate_with_groups():
448448
assert isinstance(result_ppl.stages[0], stages.Aggregate)
449449
assert list(result_ppl.stages[0].groups) == [Field.of("author")]
450450
assert list(result_ppl.stages[0].accumulators) == [Field.of("title")]
451+
452+
453+
def test_async_pipeline_raw_stage_with_options():
454+
from google.cloud.firestore_v1.base_vector_query import Field
455+
from google.cloud.firestore_v1.pipeline_stages import RawStage
456+
457+
start_ppl = _make_async_pipeline()
458+
result_ppl = start_ppl.raw_stage(
459+
"stage_name", Field.of("n"), options={"key": "val"}
460+
)
461+
assert len(start_ppl.stages) == 0
462+
assert len(result_ppl.stages) == 1
463+
assert isinstance(result_ppl.stages[0], RawStage)
464+
assert result_ppl.stages[0].options == {"key": "val"}

packages/google-cloud-firestore/tests/unit/v1/test_pipeline.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,3 +437,17 @@ def test_pipeline_aggregate_with_groups():
437437
assert isinstance(result_ppl.stages[0], stages.Aggregate)
438438
assert list(result_ppl.stages[0].groups) == [Field.of("author")]
439439
assert list(result_ppl.stages[0].accumulators) == [Field.of("title")]
440+
441+
442+
def test_pipeline_raw_stage_with_options():
443+
from google.cloud.firestore_v1.base_vector_query import Field
444+
from google.cloud.firestore_v1.pipeline_stages import RawStage
445+
446+
start_ppl = _make_pipeline()
447+
result_ppl = start_ppl.raw_stage(
448+
"stage_name", Field.of("n"), options={"key": "val"}
449+
)
450+
assert len(start_ppl.stages) == 0
451+
assert len(result_ppl.stages) == 1
452+
assert isinstance(result_ppl.stages[0], RawStage)
453+
assert result_ppl.stages[0].options == {"key": "val"}

0 commit comments

Comments
 (0)