Skip to content

Commit 6657ecc

Browse files
author
Alex Wang
committed
Add decorator for named parallel branch
1 parent 5a8f410 commit 6657ecc

5 files changed

Lines changed: 206 additions & 15 deletions

File tree

examples/src/parallel/parallel_with_named_branches.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,50 @@
1-
"""Example demonstrating parallel operations with named branches."""
1+
"""Example demonstrating all parallel branch patterns."""
22

33
from typing import Any
44

55
from aws_durable_execution_sdk_python.config import ParallelBranch, ParallelConfig
6-
from aws_durable_execution_sdk_python.context import DurableContext
6+
from aws_durable_execution_sdk_python.context import (
7+
DurableContext,
8+
durable_parallel_branch,
9+
)
710
from aws_durable_execution_sdk_python.execution import durable_execution
811

912

13+
@durable_parallel_branch(name="fetch-orders")
14+
def fetch_orders(ctx: DurableContext) -> str:
15+
return ctx.step(lambda _: "orders-loaded", name="load_orders")
16+
17+
18+
@durable_parallel_branch()
19+
def fetch_preferences(ctx: DurableContext) -> str:
20+
return ctx.step(lambda _: "prefs-loaded", name="load_prefs")
21+
22+
1023
@durable_execution
1124
def handler(_event: Any, context: DurableContext) -> list[str]:
12-
"""Execute named parallel branches using ParallelBranch."""
25+
"""Execute parallel branches using all supported patterns."""
1326

1427
return context.parallel(
1528
functions=[
29+
# 1. Named parallel branch with ParallelBranch
1630
ParallelBranch(
1731
func=lambda ctx: ctx.step(
1832
lambda _: "user-data-loaded", name="load_user"
1933
),
2034
name="fetch-user-data",
2135
),
36+
# 2. Named parallel branch with decorator
37+
fetch_orders(),
38+
# 3. Unnamed parallel branch with decorator
39+
fetch_preferences(),
40+
# 4. Unnamed parallel branch with ParallelBranch
2241
ParallelBranch(
2342
func=lambda ctx: ctx.step(
24-
lambda _: "orders-loaded", name="load_orders"
43+
lambda _: "metrics-loaded", name="load_metrics"
2544
),
26-
name="fetch-order-history",
27-
),
28-
ParallelBranch(
29-
func=lambda ctx: ctx.step(lambda _: "prefs-loaded", name="load_prefs"),
30-
name="fetch-preferences",
3145
),
46+
# 5. No wrapper, just a raw callable
47+
lambda ctx: ctx.step(lambda _: "config-loaded", name="load_config"),
3248
],
3349
name="load_all_data",
3450
config=ParallelConfig(max_concurrency=3),

examples/test/parallel/test_parallel_with_named_branches.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
lambda_function_name="parallel with named branches",
1818
)
1919
def test_parallel_with_named_branches(durable_runner):
20-
"""Test parallel example with named branches using ParallelBranch."""
20+
"""Test parallel example with all branch patterns."""
2121
with durable_runner:
2222
result = durable_runner.run(input="test", timeout=10)
2323

@@ -26,18 +26,30 @@ def test_parallel_with_named_branches(durable_runner):
2626
"user-data-loaded",
2727
"orders-loaded",
2828
"prefs-loaded",
29+
"metrics-loaded",
30+
"config-loaded",
2931
]
3032

3133
# Get the parallel operation
3234
parallel_op = result.get_context("load_all_data")
3335
assert parallel_op is not None
3436
assert parallel_op.status is OperationStatus.SUCCEEDED
3537

36-
# Verify custom branch names from ParallelBranch
37-
assert len(parallel_op.child_operations) == 3
38-
child_names = {op.name for op in parallel_op.child_operations}
39-
expected_names = {"fetch-user-data", "fetch-order-history", "fetch-preferences"}
40-
assert child_names == expected_names
38+
# Verify branch names: named branches have custom names, unnamed use defaults
39+
assert len(parallel_op.child_operations) == 5
40+
41+
child_names = [op.name for op in parallel_op.child_operations]
42+
43+
# 1. Named ParallelBranch
44+
assert child_names[0] == "fetch-user-data"
45+
# 2. Named decorator
46+
assert child_names[1] == "fetch-orders"
47+
# 3. Unnamed decorator (None name falls back to index-based default)
48+
assert child_names[2] == "parallel-branch-2"
49+
# 4. Unnamed ParallelBranch (None name falls back to index-based default)
50+
assert child_names[3] == "parallel-branch-3"
51+
# 5. Raw callable (no ParallelBranch wrapper, index-based default)
52+
assert child_names[4] == "parallel-branch-4"
4153

4254
# Verify all children succeeded
4355
for child in parallel_op.child_operations:

src/aws_durable_execution_sdk_python/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from aws_durable_execution_sdk_python.config import ParallelBranch
1111
from aws_durable_execution_sdk_python.context import (
1212
DurableContext,
13+
durable_parallel_branch,
1314
durable_step,
1415
durable_wait_for_callback,
1516
durable_with_child_context,
@@ -39,6 +40,7 @@
3940
"ValidationError",
4041
"__version__",
4142
"durable_execution",
43+
"durable_parallel_branch",
4244
"durable_step",
4345
"durable_wait_for_callback",
4446
"durable_with_child_context",

src/aws_durable_execution_sdk_python/context.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,52 @@ def function_with_arguments(child_context: DurableContext):
121121
return wrapper
122122

123123

124+
def durable_parallel_branch(
125+
name: str | None = None,
126+
) -> Callable[
127+
[Callable[Concatenate[DurableContext, Params], T]],
128+
Callable[Params, ParallelBranch[T]],
129+
]:
130+
"""Wrap your callable into a named ParallelBranch for use with context.parallel().
131+
132+
This is a decorator factory — call it with an optional name to produce
133+
the actual decorator.
134+
135+
Args:
136+
name: Optional custom name for this branch. When provided, replaces
137+
the default "parallel-branch-{index}" naming in execution history.
138+
If None, the function's __name__ is used.
139+
140+
Example:
141+
@durable_parallel_branch(name="fetch-user-data")
142+
def fetch_user(ctx: DurableContext, user_id: str) -> dict:
143+
return ctx.step(lambda _: {"id": user_id, "name": "Jane"}, name="load_user")
144+
145+
@durable_parallel_branch(name="fetch-orders")
146+
def fetch_orders(ctx: DurableContext, user_id: str) -> list:
147+
return ctx.step(lambda _: ["order1", "order2"], name="load_orders")
148+
149+
# Usage in a durable handler:
150+
results = context.parallel(
151+
functions=[fetch_user(user_id), fetch_orders(user_id)],
152+
name="load-data",
153+
)
154+
"""
155+
156+
def decorator(
157+
func: Callable[Concatenate[DurableContext, Params], T],
158+
) -> Callable[Params, ParallelBranch[T]]:
159+
def wrapper(*args, **kwargs) -> ParallelBranch[T]:
160+
def function_with_arguments(ctx: DurableContext) -> T:
161+
return func(ctx, *args, **kwargs)
162+
163+
return ParallelBranch(func=function_with_arguments, name=name)
164+
165+
return wrapper
166+
167+
return decorator
168+
169+
124170
def durable_wait_for_callback(
125171
func: Callable[Concatenate[str, WaitForCallbackContext, Params], T],
126172
) -> Callable[Params, Callable[[str, WaitForCallbackContext], T]]:

tests/context_test.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414
Duration,
1515
InvokeConfig,
1616
MapConfig,
17+
ParallelBranch,
1718
ParallelConfig,
1819
StepConfig,
1920
)
2021
from aws_durable_execution_sdk_python.context import (
2122
Callback,
2223
DurableContext,
2324
ExecutionContext,
25+
durable_parallel_branch,
2426
)
2527
from aws_durable_execution_sdk_python.exceptions import (
2628
CallbackError,
@@ -2160,3 +2162,116 @@ def test_should_propagate_outer_parent_id_when_virtual_is_nested_in_virtual():
21602162

21612163

21622164
# endregion Virtual-context identity tests
2165+
2166+
2167+
# region durable_parallel_branch
2168+
2169+
2170+
def test_durable_parallel_branch_returns_parallel_branch_with_name():
2171+
"""Test that the decorator produces a ParallelBranch with the given name."""
2172+
2173+
@durable_parallel_branch(name="fetch-user-data")
2174+
def fetch_user(ctx: DurableContext, user_id: str) -> dict:
2175+
return {"id": user_id}
2176+
2177+
result = fetch_user("user-123")
2178+
2179+
assert isinstance(result, ParallelBranch)
2180+
assert result.name == "fetch-user-data"
2181+
2182+
2183+
def test_durable_parallel_branch_with_no_name():
2184+
"""Test that when name is None, ParallelBranch.name is None."""
2185+
2186+
@durable_parallel_branch()
2187+
def fetch_orders(ctx: DurableContext) -> list:
2188+
return ["order1"]
2189+
2190+
result = fetch_orders()
2191+
2192+
assert isinstance(result, ParallelBranch)
2193+
assert result.name is None
2194+
2195+
2196+
def test_durable_parallel_branch_callable_delegates_to_func():
2197+
"""Test that calling the ParallelBranch delegates to the wrapped function."""
2198+
2199+
@durable_parallel_branch(name="my-branch")
2200+
def my_branch(ctx: DurableContext, value: int) -> int:
2201+
return value * 2
2202+
2203+
branch = my_branch(21)
2204+
mock_ctx = Mock(spec=DurableContext)
2205+
2206+
result = branch(mock_ctx)
2207+
2208+
assert result == 42
2209+
2210+
2211+
def test_durable_parallel_branch_with_multiple_args_and_kwargs():
2212+
"""Test that positional and keyword arguments are correctly bound."""
2213+
2214+
@durable_parallel_branch(name="compute")
2215+
def compute(ctx: DurableContext, a: int, b: int, op: str = "add") -> str:
2216+
if op == "add":
2217+
return f"{a + b}"
2218+
return f"{a * b}"
2219+
2220+
branch = compute(3, 4, op="mul")
2221+
mock_ctx = Mock(spec=DurableContext)
2222+
2223+
result = branch(mock_ctx)
2224+
2225+
assert result == "12"
2226+
2227+
2228+
def test_durable_parallel_branch_passes_context_as_first_arg():
2229+
"""Test that the DurableContext is passed as the first argument to the function."""
2230+
received_ctx = None
2231+
2232+
@durable_parallel_branch(name="capture-ctx")
2233+
def capture(ctx: DurableContext) -> str:
2234+
nonlocal received_ctx
2235+
received_ctx = ctx
2236+
return "done"
2237+
2238+
branch = capture()
2239+
mock_ctx = Mock(spec=DurableContext)
2240+
branch(mock_ctx)
2241+
2242+
assert received_ctx is mock_ctx
2243+
2244+
2245+
def test_durable_parallel_branch_multiple_invocations_are_independent():
2246+
"""Test that calling the wrapper multiple times produces independent branches."""
2247+
2248+
@durable_parallel_branch(name="greet")
2249+
def greet(ctx: DurableContext, name: str) -> str:
2250+
return f"hello {name}"
2251+
2252+
branch_a = greet("Alice")
2253+
branch_b = greet("Bob")
2254+
2255+
mock_ctx = Mock(spec=DurableContext)
2256+
2257+
assert branch_a(mock_ctx) == "hello Alice"
2258+
assert branch_b(mock_ctx) == "hello Bob"
2259+
2260+
2261+
def test_durable_parallel_branch_is_compatible_with_parallel_functions_arg():
2262+
"""Test that the result can be used in a functions list alongside plain callables."""
2263+
2264+
@durable_parallel_branch(name="named-branch")
2265+
def named(ctx: DurableContext) -> str:
2266+
return "named"
2267+
2268+
plain = lambda ctx: "plain" # noqa: E731
2269+
2270+
functions = [named(), plain]
2271+
2272+
assert isinstance(functions[0], ParallelBranch)
2273+
assert callable(functions[0])
2274+
assert callable(functions[1])
2275+
2276+
2277+
# endregion durable_parallel_branch

0 commit comments

Comments
 (0)