|
14 | 14 | Duration, |
15 | 15 | InvokeConfig, |
16 | 16 | MapConfig, |
| 17 | + ParallelBranch, |
17 | 18 | ParallelConfig, |
18 | 19 | StepConfig, |
19 | 20 | ) |
20 | 21 | from aws_durable_execution_sdk_python.context import ( |
21 | 22 | Callback, |
22 | 23 | DurableContext, |
23 | 24 | ExecutionContext, |
| 25 | + durable_parallel_branch, |
24 | 26 | ) |
25 | 27 | from aws_durable_execution_sdk_python.exceptions import ( |
26 | 28 | CallbackError, |
@@ -2160,3 +2162,116 @@ def test_should_propagate_outer_parent_id_when_virtual_is_nested_in_virtual(): |
2160 | 2162 |
|
2161 | 2163 |
|
2162 | 2164 | # endregion Virtual-context identity tests |
| 2165 | + |
| 2166 | + |
| 2167 | +# region durable_parallel_branch |
| 2168 | + |
| 2169 | + |
| 2170 | +def test_durable_parallel_branch_returns_parallel_branch_with_name(): |
| 2171 | + """Test that the decorator produces a ParallelBranch with the given name.""" |
| 2172 | + |
| 2173 | + @durable_parallel_branch(name="fetch-user-data") |
| 2174 | + def fetch_user(ctx: DurableContext, user_id: str) -> dict: |
| 2175 | + return {"id": user_id} |
| 2176 | + |
| 2177 | + result = fetch_user("user-123") |
| 2178 | + |
| 2179 | + assert isinstance(result, ParallelBranch) |
| 2180 | + assert result.name == "fetch-user-data" |
| 2181 | + |
| 2182 | + |
| 2183 | +def test_durable_parallel_branch_with_no_name(): |
| 2184 | + """Test that when name is None, ParallelBranch.name is None.""" |
| 2185 | + |
| 2186 | + @durable_parallel_branch() |
| 2187 | + def fetch_orders(ctx: DurableContext) -> list: |
| 2188 | + return ["order1"] |
| 2189 | + |
| 2190 | + result = fetch_orders() |
| 2191 | + |
| 2192 | + assert isinstance(result, ParallelBranch) |
| 2193 | + assert result.name is None |
| 2194 | + |
| 2195 | + |
| 2196 | +def test_durable_parallel_branch_callable_delegates_to_func(): |
| 2197 | + """Test that calling the ParallelBranch delegates to the wrapped function.""" |
| 2198 | + |
| 2199 | + @durable_parallel_branch(name="my-branch") |
| 2200 | + def my_branch(ctx: DurableContext, value: int) -> int: |
| 2201 | + return value * 2 |
| 2202 | + |
| 2203 | + branch = my_branch(21) |
| 2204 | + mock_ctx = Mock(spec=DurableContext) |
| 2205 | + |
| 2206 | + result = branch(mock_ctx) |
| 2207 | + |
| 2208 | + assert result == 42 |
| 2209 | + |
| 2210 | + |
| 2211 | +def test_durable_parallel_branch_with_multiple_args_and_kwargs(): |
| 2212 | + """Test that positional and keyword arguments are correctly bound.""" |
| 2213 | + |
| 2214 | + @durable_parallel_branch(name="compute") |
| 2215 | + def compute(ctx: DurableContext, a: int, b: int, op: str = "add") -> str: |
| 2216 | + if op == "add": |
| 2217 | + return f"{a + b}" |
| 2218 | + return f"{a * b}" |
| 2219 | + |
| 2220 | + branch = compute(3, 4, op="mul") |
| 2221 | + mock_ctx = Mock(spec=DurableContext) |
| 2222 | + |
| 2223 | + result = branch(mock_ctx) |
| 2224 | + |
| 2225 | + assert result == "12" |
| 2226 | + |
| 2227 | + |
| 2228 | +def test_durable_parallel_branch_passes_context_as_first_arg(): |
| 2229 | + """Test that the DurableContext is passed as the first argument to the function.""" |
| 2230 | + received_ctx = None |
| 2231 | + |
| 2232 | + @durable_parallel_branch(name="capture-ctx") |
| 2233 | + def capture(ctx: DurableContext) -> str: |
| 2234 | + nonlocal received_ctx |
| 2235 | + received_ctx = ctx |
| 2236 | + return "done" |
| 2237 | + |
| 2238 | + branch = capture() |
| 2239 | + mock_ctx = Mock(spec=DurableContext) |
| 2240 | + branch(mock_ctx) |
| 2241 | + |
| 2242 | + assert received_ctx is mock_ctx |
| 2243 | + |
| 2244 | + |
| 2245 | +def test_durable_parallel_branch_multiple_invocations_are_independent(): |
| 2246 | + """Test that calling the wrapper multiple times produces independent branches.""" |
| 2247 | + |
| 2248 | + @durable_parallel_branch(name="greet") |
| 2249 | + def greet(ctx: DurableContext, name: str) -> str: |
| 2250 | + return f"hello {name}" |
| 2251 | + |
| 2252 | + branch_a = greet("Alice") |
| 2253 | + branch_b = greet("Bob") |
| 2254 | + |
| 2255 | + mock_ctx = Mock(spec=DurableContext) |
| 2256 | + |
| 2257 | + assert branch_a(mock_ctx) == "hello Alice" |
| 2258 | + assert branch_b(mock_ctx) == "hello Bob" |
| 2259 | + |
| 2260 | + |
| 2261 | +def test_durable_parallel_branch_is_compatible_with_parallel_functions_arg(): |
| 2262 | + """Test that the result can be used in a functions list alongside plain callables.""" |
| 2263 | + |
| 2264 | + @durable_parallel_branch(name="named-branch") |
| 2265 | + def named(ctx: DurableContext) -> str: |
| 2266 | + return "named" |
| 2267 | + |
| 2268 | + plain = lambda ctx: "plain" # noqa: E731 |
| 2269 | + |
| 2270 | + functions = [named(), plain] |
| 2271 | + |
| 2272 | + assert isinstance(functions[0], ParallelBranch) |
| 2273 | + assert callable(functions[0]) |
| 2274 | + assert callable(functions[1]) |
| 2275 | + |
| 2276 | + |
| 2277 | +# endregion durable_parallel_branch |
0 commit comments