Commit 7ea54e6

LangGraph streaming with workflow streams (#1500)
* first pass at langgraph streaming

* Trim obvious comments in langgraph activity/config

* Remove unrelated runtime tests

* Tidy langgraph streaming tests

* don't store workflowstream

* remove timeout

* fix lint

* Make langgraph streaming opt-in via streaming_topic

* add streaming support disclaimer

* mention streaming in readme

* Validate WorkflowStream registration when streaming_topic is set

  The LangGraph interceptor now checks at workflow start that a WorkflowStream has been registered (via the publish signal handler) when the plugin was configured with streaming_topic. Misconfigured workflows fail fast with a clear error pointing at @workflow.init, instead of silently producing no-op streams.

* Stream from workflow-side LangGraph nodes via in-workflow WorkflowStream

  Wrap execute_in='workflow' nodes with wrap_workflow(), which mirrors wrap_activity() and (when streaming_topic is set) overrides the LangGraph Runtime's stream_writer to publish synchronously to the in-workflow WorkflowStream — no signal round-trip. Parametrized the streaming test over execute_in to cover both paths.

* Document streaming feature in README and plugin docstring

  Expand the README streaming section with a self-contained snippet (plugin, WorkflowStream in __init__, external subscriber loop), an explicit callout that streaming_topic only covers stream_mode='custom' with an astream() bridge example for other modes, and at-least-once retry semantics. Add an Args section to LangGraphPlugin's docstring covering all constructor parameters.

* Drop compose-mechanisms paragraph from streaming README

* Support sync nodes for streaming and execute_in='workflow'

  Pick the raw user function from runnable.func instead of LangGraph's async runnable.afunc adapter, which wraps sync nodes in loop.run_in_executor — that's incompatible with the workflow event loop.

  wrap_activity now schedules sync funcs on a thread via asyncio.to_thread so the activity loop stays free for the streaming flusher, with stream_writer calls marshaled back to the loop thread to keep the workflow_streams client's asyncio.Event safe. Parametrize the streaming test over (execute_in, sync/async).

* Fix astream-publish test race with subscriber ack

  The workflow was publishing chunk_b and the done marker in the same workflow task as its return, leaving no chance for the subscriber's next poll to land on a running workflow. Add an ack_done signal the subscriber sends after seeing done; the workflow waits for it before returning. Also hoist a signature() lookup out of the activity wrapper hot path.

* Add CODEOWNERS entries for langgraph contrib

* Drop blank line after wrap_activity docstring (D202)

* Skip workflow-side streaming tests on Python 3.10

  LangGraph's astream uses asyncio.create_task internally, and Python 3.10 doesn't propagate contextvars through new tasks. As a result get_stream_writer() returns "outside of a runnable context" when the node executes in-workflow under streaming_topic. Activity-side streaming is unaffected because the activity wrapper sets the runtime contextvar explicitly within the same task as the user node.

  This matches the existing 3.10 limitation already documented on the plugin (interrupts and the Functional API are also gated on 3.11+).

* Move 3.10 skip onto the parametrize value

* Fix streaming-ws test race with subscriber ack

  In the async-workflow case the node runs inline in the workflow with no awaits, so ainvoke and the workflow return in the same task as the publishes. The subscriber's first poll lands after completion and gets zero items. Add an ack_done signal the subscriber sends after seeing done; the workflow waits for it before returning. Mirrors 32818b1 for AstreamPublishWorkflow.

1 parent 4d6348e commit 7ea54e6

9 files changed

Lines changed: 467 additions & 44 deletions

.github/CODEOWNERS

Lines changed: 2 additions & 0 deletions
@@ -11,8 +11,10 @@
 # as well as @temporalio/sdk, so the SDK team can continue to
 # manage repo-wide concerns.
 /temporalio/contrib/google_adk_agents/ @temporalio/ai-sdk @temporalio/sdk
+/temporalio/contrib/langgraph/ @temporalio/ai-sdk @temporalio/sdk
 /temporalio/contrib/langsmith/ @temporalio/ai-sdk @temporalio/sdk
 /temporalio/contrib/openai_agents/ @temporalio/ai-sdk @temporalio/sdk
 /tests/contrib/google_adk_agents/ @temporalio/ai-sdk @temporalio/sdk
+/tests/contrib/langgraph/ @temporalio/ai-sdk @temporalio/sdk
 /tests/contrib/langsmith/ @temporalio/ai-sdk @temporalio/sdk
 /tests/contrib/openai_agents/ @temporalio/ai-sdk @temporalio/sdk

temporalio/contrib/langgraph/README.md

Lines changed: 101 additions & 0 deletions
@@ -143,6 +143,107 @@ await g.ainvoke({...}, context=Context(user_id="alice"))
 
 Your `context` object must be serializable by the configured Temporal payload converter, since it crosses the Activity boundary.
 
+## Streaming
+
+When `streaming_topic` is set on `LangGraphPlugin`, calls to `langgraph.config.get_stream_writer()` inside a node publish to the named topic on the workflow's [`WorkflowStream`](https://github.com/temporalio/sdk-python/tree/main/temporalio/contrib/workflow_streams). Activity-side nodes publish via `WorkflowStreamClient` (a signal carrying batched items, controlled by `streaming_batch_interval`); workflow-side nodes publish synchronously to the in-workflow stream (no signal). External subscribers consume the stream with `WorkflowStreamClient.create(...).topic(...).subscribe(...)`.
+
+The workflow **must** construct `WorkflowStream()` in its `@workflow.init` (i.e. `__init__`)
+
+```python
+from datetime import timedelta
+from typing import Any
+
+from langgraph.config import get_stream_writer
+from langgraph.graph import START, StateGraph
+from typing_extensions import TypedDict
+
+from temporalio import workflow
+from temporalio.client import Client
+from temporalio.contrib.langgraph import LangGraphPlugin, graph
+from temporalio.contrib.workflow_streams import WorkflowStream, WorkflowStreamClient
+from temporalio.worker import Worker
+
+
+class State(TypedDict):
+    value: str
+
+
+async def token_node(state: State) -> dict[str, str]:
+    writer = get_stream_writer()
+    for token in ["hello", " ", "world"]:
+        writer({"token": token})
+    writer({"done": True})
+    return {"value": "hello world"}
+
+
+@workflow.defn
+class StreamingWorkflow:
+    def __init__(self) -> None:
+        # Required when streaming_topic is set on the plugin.
+        _ = WorkflowStream()
+        self.app = graph("streaming").compile()
+
+    @workflow.run
+    async def run(self) -> str:
+        result = await self.app.ainvoke({"value": ""})
+        return result["value"]
+
+
+async def main(client: Client) -> None:
+    g = StateGraph(State)
+    g.add_node("token_node", token_node, metadata={"execute_in": "activity"})
+    g.add_edge(START, "token_node")
+
+    async with Worker(
+        client,
+        task_queue="streaming-tq",
+        workflows=[StreamingWorkflow],
+        plugins=[
+            LangGraphPlugin(
+                graphs={"streaming": g},
+                default_activity_options={
+                    "start_to_close_timeout": timedelta(seconds=10)
+                },
+                streaming_topic="tokens",
+            )
+        ],
+    ):
+        handle = await client.start_workflow(
+            StreamingWorkflow.run, id="streaming-wf", task_queue="streaming-tq"
+        )
+
+        ws_client = WorkflowStreamClient.create(client, handle.id)
+        async for item in ws_client.topic("tokens", type=dict).subscribe(from_offset=0):
+            print(item.data)
+            if item.data.get("done"):
+                break
+
+        print(await handle.result())
+```
+
+### What's covered, and what isn't
+
+`streaming_topic` wires up exactly **one** LangGraph stream mode: `stream_mode="custom"`, i.e. values written through `get_stream_writer()`. The other modes — `"messages"`, `"values"`, `"updates"`, `"debug"` — are **not** captured by `streaming_topic`. They aren't produced by node-side writers; LangGraph's orchestrator emits them as it walks the graph. The documented pattern is to **bridge `astream()` in the workflow** and republish each yielded chunk to a `WorkflowStream` topic yourself:
+
+```python
+@workflow.defn
+class AstreamBridge:
+    def __init__(self) -> None:
+        self.stream = WorkflowStream()
+        self.app = graph("g").compile()
+
+    @workflow.run
+    async def run(self) -> None:
+        topic = self.stream.topic("astream")
+        async for chunk in self.app.astream({...}, stream_mode="messages"):
+            topic.publish(chunk)
+        topic.publish({"done": True})
+```
+
+### Retry semantics
+
+Streaming has **at-least-once** delivery per activity attempt. When an activity-wrapped node retries (transient failure, worker crash, etc.), the user function re-runs from scratch and re-publishes its writes — earlier publishes from the failed attempt are not rolled back. Subscribers should be ready to see duplicates and recover idempotently (e.g. dedupe on a sequence id you include in each chunk, or treat the stream as advisory and rely on the workflow's final result for state).
+
 ## Tracing
 
 We recommend the [Temporal LangSmith Plugin](https://github.com/temporalio/sdk-python/tree/main/temporalio/contrib/langsmith) to trace your LangGraph Workflows and Activities.
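
The retry-semantics note in the README hunk suggests deduplicating on a sequence id that the node includes in each chunk. A minimal subscriber-side sketch of that idea, assuming each chunk is a dict carrying a hypothetical `seq` field chosen by the node author; neither `seq` nor `dedupe_chunks` is part of the plugin or of `WorkflowStream`:

```python
from collections.abc import Iterable, Iterator
from typing import Any


def dedupe_chunks(chunks: Iterable[dict[str, Any]]) -> Iterator[dict[str, Any]]:
    """Yield each chunk once, keyed on its hypothetical ``seq`` field."""
    seen: set[int] = set()
    for chunk in chunks:
        seq = chunk["seq"]
        if seq in seen:
            # A retried activity attempt re-published this chunk; skip it.
            continue
        seen.add(seq)
        yield chunk


# Simulated stream: attempt 1 publishes seq 0 and 1, then fails; the retry
# re-runs the node from scratch and publishes seq 0, 1, and 2.
received = [
    {"seq": 0, "token": "hello"},
    {"seq": 1, "token": " "},
    {"seq": 0, "token": "hello"},
    {"seq": 1, "token": " "},
    {"seq": 2, "token": "world"},
]
tokens = [chunk["token"] for chunk in dedupe_chunks(received)]
print("".join(tokens))
```

This only helps if the node assigns the same `seq` to the same logical chunk on every attempt, for example a counter that restarts at zero when the function re-runs.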

temporalio/contrib/langgraph/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@
 
 __all__ = [
     "LangGraphPlugin",
-    "entrypoint",
     "cache",
+    "entrypoint",
     "graph",
 ]

temporalio/contrib/langgraph/_activity.py

Lines changed: 46 additions & 17 deletions
@@ -1,7 +1,9 @@
 """Activity wrappers for executing LangGraph nodes and tasks."""
 
+import asyncio
 from collections.abc import Awaitable
 from dataclasses import dataclass
+from datetime import timedelta
 from inspect import iscoroutinefunction, signature
 from typing import Any, Callable
 
@@ -19,6 +21,7 @@
     cache_lookup,
     cache_put,
 )
+from temporalio.contrib.workflow_streams import WorkflowStreamClient
 
 # Per-run dedupe so we only warn once when a user passes a Store via
 # graph.compile(store=...) / @entrypoint(store=...). Cleared by
@@ -51,28 +54,54 @@ class ActivityOutput:
 
 def wrap_activity(
     func: Callable,
+    *,
+    streaming_topic: str | None = None,
+    streaming_batch_interval: timedelta = timedelta(milliseconds=100),
 ) -> Callable[[ActivityInput], Awaitable[ActivityOutput]]:
     """Wrap a function as a Temporal activity that handles LangGraph config and interrupts."""
-    # Graph nodes declare `runtime: Runtime[Ctx]` in their signature; tasks
-    # don't and instead reach for Runtime via get_runtime(). We re-inject the
-    # reconstructed Runtime only when the user function asks.
     accepts_runtime = "runtime" in signature(func).parameters
 
     async def wrapper(input: ActivityInput) -> ActivityOutput:
-        runtime = set_langgraph_config(input.langgraph_config)
-        kwargs = dict(input.kwargs)
-        if accepts_runtime:
-            kwargs["runtime"] = runtime
-        try:
-            if iscoroutinefunction(func):
-                result = await func(*input.args, **kwargs)
-            else:
-                result = func(*input.args, **kwargs)
-            if isinstance(result, Command):
-                return ActivityOutput(langgraph_command=result)
-            return ActivityOutput(result=result)
-        except GraphInterrupt as e:
-            return ActivityOutput(langgraph_interrupts=e.args[0])
+        async def run(stream_writer: Callable[[Any], None] | None) -> ActivityOutput:
+            # Sync funcs run on a thread (so the loop keeps flushing the
+            # stream client mid-execution); marshal writer calls back to
+            # the loop thread because the client's flush event is an
+            # asyncio.Event and isn't safe to set off-thread.
+            effective_writer = stream_writer
+            if not iscoroutinefunction(func) and stream_writer is not None:
+                loop = asyncio.get_running_loop()
+                inner_writer = stream_writer
+
+                def thread_safe_writer(value: Any) -> None:
+                    loop.call_soon_threadsafe(inner_writer, value)
+
+                effective_writer = thread_safe_writer
+
+            runtime = set_langgraph_config(
+                input.langgraph_config, stream_writer=effective_writer
+            )
+            kwargs = dict(input.kwargs)
+            if accepts_runtime:
+                kwargs["runtime"] = runtime
+
+            try:
+                if iscoroutinefunction(func):
+                    result = await func(*input.args, **kwargs)
+                else:
+                    result = await asyncio.to_thread(func, *input.args, **kwargs)
+                if isinstance(result, Command):
+                    return ActivityOutput(langgraph_command=result)
+                return ActivityOutput(result=result)
+            except GraphInterrupt as e:
+                return ActivityOutput(langgraph_interrupts=e.args[0])
+
+        if streaming_topic is None:
+            return await run(stream_writer=None)
+        async with WorkflowStreamClient.from_within_activity(
+            batch_interval=streaming_batch_interval,
+        ) as client:
+            topic = client.topic(streaming_topic)
+            return await run(stream_writer=topic.publish)
 
     return wrapper
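
The sync-node branch of `wrap_activity` runs the user function with `asyncio.to_thread` and marshals writer calls back to the event loop with `loop.call_soon_threadsafe`. A stripped-down sketch of that pattern using only the standard library; `run_sync_with_writer`, `demo`, and `blocking_node` are hypothetical names for illustration and not part of the plugin:

```python
import asyncio
import threading
from typing import Any, Callable


async def run_sync_with_writer(
    func: Callable[[Callable[[Any], None]], Any],
    writer: Callable[[Any], None],
) -> Any:
    """Run a blocking ``func`` on a worker thread; run ``writer`` on the loop thread."""
    loop = asyncio.get_running_loop()

    def thread_safe_writer(value: Any) -> None:
        # Called from the worker thread: schedule the real writer on the loop.
        loop.call_soon_threadsafe(writer, value)

    return await asyncio.to_thread(func, thread_safe_writer)


async def demo() -> tuple[list[str], set[int], str]:
    received: list[str] = []
    writer_threads: set[int] = set()

    def writer(value: str) -> None:
        # Record which thread the real writer runs on.
        writer_threads.add(threading.get_ident())
        received.append(value)

    def blocking_node(write: Callable[[Any], None]) -> str:
        # Stand-in for a sync LangGraph node that streams tokens.
        for token in ["hello", " ", "world"]:
            write(token)
        return "done"

    result = await run_sync_with_writer(blocking_node, writer)
    return received, writer_threads, result


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

The worker thread never touches the writer directly, so any loop-bound state behind it (such as an `asyncio.Event`) is only mutated from the loop thread, while the loop itself stays free to run other tasks during the blocking call.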

temporalio/contrib/langgraph/_interceptor.py

Lines changed: 16 additions & 0 deletions
@@ -11,6 +11,7 @@
 
 from temporalio import workflow
 from temporalio.contrib.langgraph._activity import clear_store_warning
+from temporalio.contrib.workflow_streams._stream import _PUBLISH_SIGNAL
 from temporalio.worker import (
     ExecuteWorkflowInput,
     Interceptor,
@@ -30,17 +31,20 @@ def __init__(
         self,
         graphs: dict[str, StateGraph[Any, Any, Any, Any]],
         entrypoints: dict[str, Pregel[Any, Any, Any, Any]],
+        streaming_topic: str | None = None,
     ) -> None:
         """Initialize with the graphs and entrypoints to scope to each workflow run."""
         self._graphs = graphs
         self._entrypoints = entrypoints
+        self._streaming_topic = streaming_topic
 
     def workflow_interceptor_class(
         self, input: WorkflowInterceptorClassInput
     ) -> type[WorkflowInboundInterceptor]:
         """Return the inbound interceptor class used to scope graphs per run."""
         graphs = self._graphs
         entrypoints = self._entrypoints
+        streaming_topic = self._streaming_topic
 
         class Inbound(WorkflowInboundInterceptor):
             def init(self, outbound: WorkflowOutboundInterceptor) -> None:
@@ -50,6 +54,18 @@ def init(self, outbound: WorkflowOutboundInterceptor) -> None:
                 super().init(outbound)
 
             async def execute_workflow(self, input: ExecuteWorkflowInput) -> Any:
+                if (
+                    streaming_topic is not None
+                    and workflow.get_signal_handler(_PUBLISH_SIGNAL) is None
+                ):
+                    raise RuntimeError(
+                        f"LangGraphPlugin was configured with "
+                        f"streaming_topic={streaming_topic!r}, but workflow "
+                        f"{workflow.info().workflow_type!r} did not register a "
+                        f"WorkflowStream. Construct WorkflowStream() in the "
+                        f"workflow's @workflow.init (i.e. __init__) method so "
+                        f"streaming activities can publish to it."
+                    )
                 try:
                     return await self.next.execute_workflow(input)
                 finally:

temporalio/contrib/langgraph/_langgraph_config.py

Lines changed: 7 additions & 3 deletions
@@ -3,7 +3,7 @@
 # pyright: reportMissingTypeStubs=false
 
 import dataclasses
-from typing import Any
+from typing import Any, Callable
 
 from langchain_core.runnables.config import var_child_runnable_config
 from langgraph._internal._constants import (
@@ -93,7 +93,11 @@ def get_langgraph_config() -> dict[str, Any]:
     }
 
 
-def set_langgraph_config(config: dict[str, Any]) -> Runtime:
+def set_langgraph_config(
+    config: dict[str, Any],
+    *,
+    stream_writer: Callable[[Any], None] | None = None,
+) -> Runtime:
     """Restore a LangGraph runnable config from a serialized dict.
 
     Returns the reconstructed Runtime so callers can re-inject it into the
@@ -112,7 +116,7 @@ def get_null_resume(consume: bool = False) -> Any:
     execution_info_dict = config.get("execution_info")
     runtime = Runtime(
         context=config.get("context"),
-        stream_writer=lambda _: None,
+        stream_writer=stream_writer or (lambda _: None),
         previous=config.get("previous"),
         execution_info=(
             ExecutionInfo(**execution_info_dict) if execution_info_dict else None
