Skip to content

Commit 8e11436

Browse files
author
Alex Wang
committed
feat: parallel and map branch name
1 parent 85f2d24 commit 8e11436

15 files changed

Lines changed: 519 additions & 12 deletions

File tree

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,6 @@ dist/
3131
.kiro/
3232

3333
/examples/build/*
34-
/examples/*.zip
34+
/examples/*.zip
35+
36+
.env

examples/examples-catalog.json

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,28 @@
580580
"ApplicationLogLevel": "DEBUG",
581581
"LogFormat": "JSON"
582582
}
583-
}
583+
},
584+
{
585+
"name": "Map with Item Namer",
586+
"description": "Map operation with custom item_namer for iteration naming",
587+
"handler": "map_with_item_namer.handler",
588+
"integration": true,
589+
"durableConfig": {
590+
"RetentionPeriodInDays": 7,
591+
"ExecutionTimeout": 300
592+
},
593+
"path": "./src/map/map_with_item_namer.py"
594+
},
595+
{
596+
"name": "Parallel with Named Branches",
597+
"description": "Parallel operation with named branches using ParallelBranch",
598+
"handler": "parallel_with_named_branches.handler",
599+
"integration": true,
600+
"durableConfig": {
601+
"RetentionPeriodInDays": 7,
602+
"ExecutionTimeout": 300
603+
},
604+
"path": "./src/parallel/parallel_with_named_branches.py"
605+
}
584606
]
585607
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
"""Example demonstrating map operations with custom iteration naming."""
2+
3+
from typing import Any
4+
5+
from aws_durable_execution_sdk_python.config import MapConfig
6+
from aws_durable_execution_sdk_python.context import DurableContext
7+
from aws_durable_execution_sdk_python.execution import durable_execution
8+
9+
10+
@durable_execution
11+
def handler(_event: Any, context: DurableContext) -> list[str]:
12+
"""Process orders using context.map() with custom iteration names."""
13+
orders = [
14+
{"id": "order-101", "amount": 25},
15+
{"id": "order-102", "amount": 50},
16+
{"id": "order-103", "amount": 75},
17+
]
18+
19+
return context.map(
20+
inputs=orders,
21+
func=lambda ctx, order, index, _: ctx.step(
22+
lambda _: f"processed-{order['id']}-${order['amount']}",
23+
name=f"process_{order['id']}",
24+
),
25+
name="process_orders",
26+
config=MapConfig(
27+
max_concurrency=2,
28+
item_namer=lambda order, index: f"order-{order['id']}",
29+
),
30+
).get_results()
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""Example demonstrating parallel operations with named branches."""
2+
3+
from typing import Any
4+
5+
from aws_durable_execution_sdk_python.config import ParallelBranch, ParallelConfig
6+
from aws_durable_execution_sdk_python.context import DurableContext
7+
from aws_durable_execution_sdk_python.execution import durable_execution
8+
9+
10+
@durable_execution
11+
def handler(_event: Any, context: DurableContext) -> list[str]:
12+
"""Execute named parallel branches using ParallelBranch."""
13+
14+
return context.parallel(
15+
functions=[
16+
ParallelBranch(
17+
func=lambda ctx: ctx.step(
18+
lambda _: "user-data-loaded", name="load_user"
19+
),
20+
name="fetch-user-data",
21+
),
22+
ParallelBranch(
23+
func=lambda ctx: ctx.step(
24+
lambda _: "orders-loaded", name="load_orders"
25+
),
26+
name="fetch-order-history",
27+
),
28+
ParallelBranch(
29+
func=lambda ctx: ctx.step(lambda _: "prefs-loaded", name="load_prefs"),
30+
name="fetch-preferences",
31+
),
32+
],
33+
name="load_all_data",
34+
config=ParallelConfig(max_concurrency=3),
35+
).get_results()
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
"""Tests for map_with_item_namer example."""
2+
3+
import pytest
4+
from src.map import map_with_item_namer
5+
from test.conftest import deserialize_operation_payload
6+
7+
from aws_durable_execution_sdk_python.execution import InvocationStatus
8+
from aws_durable_execution_sdk_python.lambda_service import (
9+
OperationStatus,
10+
)
11+
12+
13+
@pytest.mark.example
14+
@pytest.mark.durable_execution(
15+
handler=map_with_item_namer.handler,
16+
lambda_function_name="map with item namer",
17+
)
18+
def test_map_with_item_namer(durable_runner):
19+
"""Test map example with custom item_namer for iteration naming."""
20+
with durable_runner:
21+
result = durable_runner.run(input="test", timeout=10)
22+
23+
assert result.status is InvocationStatus.SUCCEEDED
24+
assert deserialize_operation_payload(result.result) == [
25+
"processed-order-101-$25",
26+
"processed-order-102-$50",
27+
"processed-order-103-$75",
28+
]
29+
30+
# Get the map operation
31+
map_op = result.get_context("process_orders")
32+
assert map_op is not None
33+
assert map_op.status is OperationStatus.SUCCEEDED
34+
35+
# Verify custom iteration names from item_namer
36+
assert len(map_op.child_operations) == 3
37+
child_names = {op.name for op in map_op.child_operations}
38+
expected_names = {"order-order-101", "order-order-102", "order-order-103"}
39+
assert child_names == expected_names
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""Tests for parallel_with_named_branches example."""
2+
3+
import pytest
4+
from src.parallel import parallel_with_named_branches
5+
from test.conftest import deserialize_operation_payload
6+
7+
from aws_durable_execution_sdk_python.execution import InvocationStatus
8+
from aws_durable_execution_sdk_python.lambda_service import (
9+
OperationStatus,
10+
OperationType,
11+
)
12+
13+
14+
@pytest.mark.example
15+
@pytest.mark.durable_execution(
16+
handler=parallel_with_named_branches.handler,
17+
lambda_function_name="parallel with named branches",
18+
)
19+
def test_parallel_with_named_branches(durable_runner):
20+
"""Test parallel example with named branches using ParallelBranch."""
21+
with durable_runner:
22+
result = durable_runner.run(input="test", timeout=10)
23+
24+
assert result.status is InvocationStatus.SUCCEEDED
25+
assert deserialize_operation_payload(result.result) == [
26+
"user-data-loaded",
27+
"orders-loaded",
28+
"prefs-loaded",
29+
]
30+
31+
# Get the parallel operation
32+
parallel_op = result.get_context("load_all_data")
33+
assert parallel_op is not None
34+
assert parallel_op.status is OperationStatus.SUCCEEDED
35+
36+
# Verify custom branch names from ParallelBranch
37+
assert len(parallel_op.child_operations) == 3
38+
child_names = {op.name for op in parallel_op.child_operations}
39+
expected_names = {"fetch-user-data", "fetch-order-history", "fetch-preferences"}
40+
assert child_names == expected_names
41+
42+
# Verify all children succeeded
43+
for child in parallel_op.child_operations:
44+
assert child.operation_type == OperationType.CONTEXT
45+
assert child.status is OperationStatus.SUCCEEDED

src/aws_durable_execution_sdk_python/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# Helper decorators - commonly used for step functions
88
# Concurrency
99
from aws_durable_execution_sdk_python.concurrency.models import BatchResult
10+
from aws_durable_execution_sdk_python.config import ParallelBranch
1011
from aws_durable_execution_sdk_python.context import (
1112
DurableContext,
1213
durable_step,
@@ -27,11 +28,13 @@
2728
# Essential context types - passed to user functions
2829
from aws_durable_execution_sdk_python.types import StepContext
2930

31+
3032
__all__ = [
3133
"BatchResult",
3234
"DurableContext",
3335
"DurableExecutionsError",
3436
"InvocationError",
37+
"ParallelBranch",
3538
"StepContext",
3639
"ValidationError",
3740
"__version__",

src/aws_durable_execution_sdk_python/concurrency/executor.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,14 @@ def execute_item(
194194
"""Execute a single executable in a child context and return the result."""
195195
raise NotImplementedError
196196

197+
def get_iteration_name(self, index: int) -> str:
198+
"""Get the display name for an iteration/branch at the given index.
199+
200+
Subclasses can override this to provide custom naming (e.g., from item_namer
201+
or branch names). The default returns "{name_prefix}{index}".
202+
"""
203+
return f"{self.name_prefix}{index}"
204+
197205
def execute(
198206
self, execution_state: ExecutionState, executor_context: DurableContext
199207
) -> BatchResult[ResultType]:
@@ -410,7 +418,7 @@ def _execute_item_in_child_context(
410418
operation_id: str = executor_context._create_step_id_for_logical_step( # noqa: SLF001
411419
executable.index
412420
)
413-
name: str = f"{self.name_prefix}{executable.index}"
421+
name: str = self.get_iteration_name(executable.index)
414422
is_virtual: bool = self.nesting_type is NestingType.FLAT
415423

416424
child_context: DurableContext = executor_context.create_child_context(

src/aws_durable_execution_sdk_python/config.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from aws_durable_execution_sdk_python.exceptions import ValidationError
1111

12+
1213
P = TypeVar("P") # Payload type
1314
R = TypeVar("R") # Result type
1415
T = TypeVar("T")
@@ -245,6 +246,41 @@ class ParallelConfig:
245246
nesting_type: NestingType = NestingType.NESTED
246247

247248

249+
@dataclass(frozen=True)
250+
class ParallelBranch(Generic[T]):
251+
"""A named branch for parallel execution.
252+
253+
Use this to provide custom names for parallel branches, improving
254+
observability in execution history.
255+
256+
Type Parameters:
257+
T: The return type of the branch function.
258+
259+
Args:
260+
func: The callable to execute in this branch. Receives a DurableContext.
261+
name: Optional custom name for this branch. When provided, replaces
262+
the default "parallel-branch-{index}" naming in execution history.
263+
This affects observability but not replay determinism.
264+
265+
Example:
266+
context.parallel(
267+
functions=[
268+
ParallelBranch(func=lambda ctx: fetch_user(ctx), name="fetch-user-data"),
269+
ParallelBranch(func=lambda ctx: fetch_orders(ctx), name="fetch-order-history"),
270+
],
271+
name="load-data",
272+
config=ParallelConfig(max_concurrency=2),
273+
)
274+
"""
275+
276+
func: Callable
277+
name: str | None = None
278+
279+
def __call__(self, *args, **kwargs):
280+
"""Delegate to the wrapped function, making ParallelBranch itself callable."""
281+
return self.func(*args, **kwargs)
282+
283+
248284
class StepSemantics(Enum):
249285
AT_MOST_ONCE_PER_RETRY = "AT_MOST_ONCE_PER_RETRY"
250286
AT_LEAST_ONCE_PER_RETRY = "AT_LEAST_ONCE_PER_RETRY"
@@ -354,12 +390,15 @@ class ItemBatcher(Generic[T]):
354390

355391

356392
@dataclass(frozen=True)
357-
class MapConfig:
393+
class MapConfig(Generic[T]):
358394
"""Configuration options for map operations over collections.
359395
360396
This class configures how map operations process collections of items,
361397
including concurrency, batching, completion criteria, and serialization.
362398
399+
Type Parameters:
400+
T: The type of items being processed in the map operation.
401+
363402
Args:
364403
max_concurrency: Maximum number of items to process concurrently.
365404
If None, no limit is imposed and all items are processed concurrently.
@@ -402,13 +441,25 @@ class MapConfig:
402441
- NESTED: Each item runs in its own isolated context (default)
403442
- FLAT: All items share the same parent context
404443
444+
item_namer: Optional callable to generate custom names for each map iteration.
445+
When provided, replaces the default "map-item-{index}" naming scheme.
446+
Receives the item and its index, and returns a string name for that iteration.
447+
This affects observability (execution history names) but not replay determinism.
448+
If None, uses the default naming: "map-item-{index}".
449+
405450
Example:
406451
# Process 5 items at a time, batch by count, require all to succeed
407452
config = MapConfig(
408453
max_concurrency=5,
409454
item_batcher=ItemBatcher(max_items_per_batch=10),
410455
completion_config=CompletionConfig.all_successful()
411456
)
457+
458+
# With custom iteration names
459+
config = MapConfig(
460+
max_concurrency=5,
461+
item_namer=lambda item, index: f"process-order-{item.id}"
462+
)
412463
"""
413464

414465
max_concurrency: int | None = None
@@ -418,6 +469,7 @@ class MapConfig:
418469
item_serdes: SerDes | None = None
419470
summary_generator: SummaryGenerator | None = None
420471
nesting_type: NestingType = NestingType.NESTED
472+
item_namer: Callable[[T, int], str] | None = None
421473

422474

423475
@dataclass(frozen=True)

src/aws_durable_execution_sdk_python/context.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Duration,
1313
InvokeConfig,
1414
MapConfig,
15+
ParallelBranch,
1516
ParallelConfig,
1617
StepConfig,
1718
WaitForCallbackConfig,
@@ -55,6 +56,7 @@
5556
WaitForConditionCheckContext,
5657
)
5758

59+
5860
if TYPE_CHECKING:
5961
from collections.abc import Callable, Sequence
6062

@@ -496,7 +498,7 @@ def map_in_child_context() -> BatchResult[R]:
496498

497499
def parallel(
498500
self,
499-
functions: Sequence[Callable[[DurableContext], T]],
501+
functions: Sequence[Callable[[DurableContext], T] | ParallelBranch[T]],
500502
name: str | None = None,
501503
config: ParallelConfig | None = None,
502504
) -> BatchResult[T]:

0 commit comments

Comments
 (0)