|
1 | 1 | """Activity wrappers for executing LangGraph nodes and tasks.""" |
2 | 2 |
|
| 3 | +import asyncio |
3 | 4 | from collections.abc import Awaitable |
4 | 5 | from dataclasses import dataclass |
| 6 | +from datetime import timedelta |
5 | 7 | from inspect import iscoroutinefunction, signature |
6 | 8 | from typing import Any, Callable |
7 | 9 |
|
|
19 | 21 | cache_lookup, |
20 | 22 | cache_put, |
21 | 23 | ) |
| 24 | +from temporalio.contrib.workflow_streams import WorkflowStreamClient |
22 | 25 |
|
23 | 26 | # Per-run dedupe so we only warn once when a user passes a Store via |
24 | 27 | # graph.compile(store=...) / @entrypoint(store=...). Cleared by |
@@ -51,28 +54,54 @@ class ActivityOutput: |
51 | 54 |
|
52 | 55 | def wrap_activity( |
53 | 56 | func: Callable, |
| 57 | + *, |
| 58 | + streaming_topic: str | None = None, |
| 59 | + streaming_batch_interval: timedelta = timedelta(milliseconds=100), |
54 | 60 | ) -> Callable[[ActivityInput], Awaitable[ActivityOutput]]: |
55 | 61 | """Wrap a function as a Temporal activity that handles LangGraph config and interrupts.""" |
56 | | - # Graph nodes declare `runtime: Runtime[Ctx]` in their signature; tasks |
57 | | - # don't and instead reach for Runtime via get_runtime(). We re-inject the |
58 | | - # reconstructed Runtime only when the user function asks. |
59 | 62 | accepts_runtime = "runtime" in signature(func).parameters |
60 | 63 |
|
61 | 64 | async def wrapper(input: ActivityInput) -> ActivityOutput: |
62 | | - runtime = set_langgraph_config(input.langgraph_config) |
63 | | - kwargs = dict(input.kwargs) |
64 | | - if accepts_runtime: |
65 | | - kwargs["runtime"] = runtime |
66 | | - try: |
67 | | - if iscoroutinefunction(func): |
68 | | - result = await func(*input.args, **kwargs) |
69 | | - else: |
70 | | - result = func(*input.args, **kwargs) |
71 | | - if isinstance(result, Command): |
72 | | - return ActivityOutput(langgraph_command=result) |
73 | | - return ActivityOutput(result=result) |
74 | | - except GraphInterrupt as e: |
75 | | - return ActivityOutput(langgraph_interrupts=e.args[0]) |
| 65 | + async def run(stream_writer: Callable[[Any], None] | None) -> ActivityOutput: |
| 66 | + # Sync funcs run on a thread (so the loop keeps flushing the |
| 67 | + # stream client mid-execution); marshal writer calls back to |
| 68 | + # the loop thread because the client's flush event is an |
| 69 | + # asyncio.Event and isn't safe to set off-thread. |
| 70 | + effective_writer = stream_writer |
| 71 | + if not iscoroutinefunction(func) and stream_writer is not None: |
| 72 | + loop = asyncio.get_running_loop() |
| 73 | + inner_writer = stream_writer |
| 74 | + |
| 75 | + def thread_safe_writer(value: Any) -> None: |
| 76 | + loop.call_soon_threadsafe(inner_writer, value) |
| 77 | + |
| 78 | + effective_writer = thread_safe_writer |
| 79 | + |
| 80 | + runtime = set_langgraph_config( |
| 81 | + input.langgraph_config, stream_writer=effective_writer |
| 82 | + ) |
| 83 | + kwargs = dict(input.kwargs) |
| 84 | + if accepts_runtime: |
| 85 | + kwargs["runtime"] = runtime |
| 86 | + |
| 87 | + try: |
| 88 | + if iscoroutinefunction(func): |
| 89 | + result = await func(*input.args, **kwargs) |
| 90 | + else: |
| 91 | + result = await asyncio.to_thread(func, *input.args, **kwargs) |
| 92 | + if isinstance(result, Command): |
| 93 | + return ActivityOutput(langgraph_command=result) |
| 94 | + return ActivityOutput(result=result) |
| 95 | + except GraphInterrupt as e: |
| 96 | + return ActivityOutput(langgraph_interrupts=e.args[0]) |
| 97 | + |
| 98 | + if streaming_topic is None: |
| 99 | + return await run(stream_writer=None) |
| 100 | + async with WorkflowStreamClient.from_within_activity( |
| 101 | + batch_interval=streaming_batch_interval, |
| 102 | + ) as client: |
| 103 | + topic = client.topic(streaming_topic) |
| 104 | + return await run(stream_writer=topic.publish) |
76 | 105 |
|
77 | 106 | return wrapper |
78 | 107 |
|
|
0 commit comments