Skip to content

Commit 89b2be2

Browse files
authored
Merge branch 'main' into strands
2 parents a87789d + d53a604 commit 89b2be2

39 files changed

Lines changed: 3260 additions & 320 deletions

.github/CODEOWNERS

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212
# manage repo-wide concerns.
1313
/temporalio/contrib/common/ @temporalio/ai-sdk @temporalio/sdk
1414
/temporalio/contrib/google_adk_agents/ @temporalio/ai-sdk @temporalio/sdk
15+
/temporalio/contrib/langgraph/ @temporalio/ai-sdk @temporalio/sdk
1516
/temporalio/contrib/langsmith/ @temporalio/ai-sdk @temporalio/sdk
1617
/temporalio/contrib/openai_agents/ @temporalio/ai-sdk @temporalio/sdk
1718
/temporalio/contrib/strands/ @temporalio/ai-sdk @temporalio/sdk
1819
/tests/contrib/google_adk_agents/ @temporalio/ai-sdk @temporalio/sdk
20+
/tests/contrib/langgraph/ @temporalio/ai-sdk @temporalio/sdk
1921
/tests/contrib/langsmith/ @temporalio/ai-sdk @temporalio/sdk
2022
/tests/contrib/openai_agents/ @temporalio/ai-sdk @temporalio/sdk
2123
/tests/contrib/strands/ @temporalio/ai-sdk @temporalio/sdk

.github/workflows/omes.yml

Lines changed: 0 additions & 22 deletions
This file was deleted.

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,7 @@ class IPv4AddressJSONEncoder(AdvancedJSONEncoder):
432432
class IPv4AddressJSONTypeConverter(JSONTypeConverter):
433433
def to_typed_value(
434434
self, hint: Type, value: Any
435-
) -> Union[Optional[Any], _JSONTypeConverterUnhandled]:
435+
) -> Union[Optional[Any], JSONTypeConverterUnhandled]:
436436
if issubclass(hint, ipaddress.IPv4Address):
437437
return ipaddress.IPv4Address(value)
438438
return JSONTypeConverter.Unhandled

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ pydantic = ["pydantic>=2.0.0,<3"]
3131
openai-agents = ["openai-agents>=0.17.1", "mcp>=1.9.4, <2"]
3232
google-adk = ["google-adk>=1.27.0,<2"]
3333
langgraph = ["langgraph>=1.1.0"]
34-
langsmith = ["langsmith>=0.7.34,<0.8"]
34+
langsmith = ["langsmith>=0.7.34,<0.9"]
3535
lambda-worker-otel = [
3636
"opentelemetry-api>=1.11.1,<2",
3737
"opentelemetry-sdk>=1.11.1,<2",
@@ -79,7 +79,7 @@ dev = [
7979
"pytest-xdist>=3.6,<4",
8080
"moto[s3,server]>=5",
8181
"langgraph>=1.1.0",
82-
"langsmith>=0.7.34,<0.8",
82+
"langsmith>=0.7.34,<0.9",
8383
"setuptools<82",
8484
"opentelemetry-exporter-otlp-proto-grpc>=1.11.1,<2",
8585
"opentelemetry-semantic-conventions>=0.40b0,<1",

temporalio/client/_nexus.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,35 @@ async def start_operation(
611611
rpc_timeout: timedelta | None = None,
612612
) -> NexusOperationHandle[OutputT]: ...
613613

614+
# Overload for temporal_operation methods
615+
@overload
616+
@abstractmethod
617+
async def start_operation(
618+
self,
619+
operation: Callable[
620+
[
621+
NexusServiceType,
622+
temporalio.nexus.TemporalNexusStartOperationContext,
623+
temporalio.nexus.TemporalNexusClient,
624+
InputT,
625+
],
626+
Awaitable[temporalio.nexus.TemporalOperationResult[OutputT]],
627+
],
628+
arg: InputT,
629+
*,
630+
id: str,
631+
id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE,
632+
id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL,
633+
schedule_to_close_timeout: timedelta | None = None,
634+
schedule_to_start_timeout: timedelta | None = None,
635+
start_to_close_timeout: timedelta | None = None,
636+
search_attributes: temporalio.common.TypedSearchAttributes | None = None,
637+
summary: str | None = None,
638+
headers: Mapping[str, str] | None = None,
639+
rpc_metadata: Mapping[str, str | bytes] = {},
640+
rpc_timeout: timedelta | None = None,
641+
) -> NexusOperationHandle[OutputT]: ...
642+
614643
@abstractmethod
615644
async def start_operation(
616645
self,
@@ -804,6 +833,35 @@ async def execute_operation(
804833
rpc_timeout: timedelta | None = None,
805834
) -> OutputT: ...
806835

836+
# Overload for temporal_operation methods
837+
@overload
838+
@abstractmethod
839+
async def execute_operation(
840+
self,
841+
operation: Callable[
842+
[
843+
NexusServiceType,
844+
temporalio.nexus.TemporalNexusStartOperationContext,
845+
temporalio.nexus.TemporalNexusClient,
846+
InputT,
847+
],
848+
Awaitable[temporalio.nexus.TemporalOperationResult[OutputT]],
849+
],
850+
arg: InputT,
851+
*,
852+
id: str,
853+
id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE,
854+
id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL,
855+
schedule_to_close_timeout: timedelta | None = None,
856+
schedule_to_start_timeout: timedelta | None = None,
857+
start_to_close_timeout: timedelta | None = None,
858+
search_attributes: temporalio.common.TypedSearchAttributes | None = None,
859+
summary: str | None = None,
860+
headers: Mapping[str, str] | None = None,
861+
rpc_metadata: Mapping[str, str | bytes] = {},
862+
rpc_timeout: timedelta | None = None,
863+
) -> OutputT: ...
864+
807865
@abstractmethod
808866
async def execute_operation(
809867
self,

temporalio/contrib/langgraph/README.md

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,107 @@ await g.ainvoke({...}, context=Context(user_id="alice"))
143143

144144
Your `context` object must be serializable by the configured Temporal payload converter, since it crosses the Activity boundary.
145145

146+
## Streaming
147+
148+
When `streaming_topic` is set on `LangGraphPlugin`, calls to `langgraph.config.get_stream_writer()` inside a node publish to the named topic on the workflow's [`WorkflowStream`](https://github.com/temporalio/sdk-python/tree/main/temporalio/contrib/workflow_streams). Activity-side nodes publish via `WorkflowStreamClient` (a signal carrying batched items, controlled by `streaming_batch_interval`); workflow-side nodes publish synchronously to the in-workflow stream (no signal). External subscribers consume the stream with `WorkflowStreamClient.create(...).topic(...).subscribe(...)`.
149+
150+
The workflow **must** construct `WorkflowStream()` in its `@workflow.init` (i.e. `__init__`)
151+
152+
```python
153+
from datetime import timedelta
154+
from typing import Any
155+
156+
from langgraph.config import get_stream_writer
157+
from langgraph.graph import START, StateGraph
158+
from typing_extensions import TypedDict
159+
160+
from temporalio import workflow
161+
from temporalio.client import Client
162+
from temporalio.contrib.langgraph import LangGraphPlugin, graph
163+
from temporalio.contrib.workflow_streams import WorkflowStream, WorkflowStreamClient
164+
from temporalio.worker import Worker
165+
166+
167+
class State(TypedDict):
168+
value: str
169+
170+
171+
async def token_node(state: State) -> dict[str, str]:
172+
writer = get_stream_writer()
173+
for token in ["hello", " ", "world"]:
174+
writer({"token": token})
175+
writer({"done": True})
176+
return {"value": "hello world"}
177+
178+
179+
@workflow.defn
180+
class StreamingWorkflow:
181+
def __init__(self) -> None:
182+
# Required when streaming_topic is set on the plugin.
183+
_ = WorkflowStream()
184+
self.app = graph("streaming").compile()
185+
186+
@workflow.run
187+
async def run(self) -> str:
188+
result = await self.app.ainvoke({"value": ""})
189+
return result["value"]
190+
191+
192+
async def main(client: Client) -> None:
193+
g = StateGraph(State)
194+
g.add_node("token_node", token_node, metadata={"execute_in": "activity"})
195+
g.add_edge(START, "token_node")
196+
197+
async with Worker(
198+
client,
199+
task_queue="streaming-tq",
200+
workflows=[StreamingWorkflow],
201+
plugins=[
202+
LangGraphPlugin(
203+
graphs={"streaming": g},
204+
default_activity_options={
205+
"start_to_close_timeout": timedelta(seconds=10)
206+
},
207+
streaming_topic="tokens",
208+
)
209+
],
210+
):
211+
handle = await client.start_workflow(
212+
StreamingWorkflow.run, id="streaming-wf", task_queue="streaming-tq"
213+
)
214+
215+
ws_client = WorkflowStreamClient.create(client, handle.id)
216+
async for item in ws_client.topic("tokens", type=dict).subscribe(from_offset=0):
217+
print(item.data)
218+
if item.data.get("done"):
219+
break
220+
221+
print(await handle.result())
222+
```
223+
224+
### What's covered, and what isn't
225+
226+
`streaming_topic` wires up exactly **one** LangGraph stream mode: `stream_mode="custom"`, i.e. values written through `get_stream_writer()`. The other modes — `"messages"`, `"values"`, `"updates"`, `"debug"` — are **not** captured by `streaming_topic`. They aren't produced by node-side writers; LangGraph's orchestrator emits them as it walks the graph. The documented pattern is to **bridge `astream()` in the workflow** and republish each yielded chunk to a `WorkflowStream` topic yourself:
227+
228+
```python
229+
@workflow.defn
230+
class AstreamBridge:
231+
def __init__(self) -> None:
232+
self.stream = WorkflowStream()
233+
self.app = graph("g").compile()
234+
235+
@workflow.run
236+
async def run(self) -> None:
237+
topic = self.stream.topic("astream")
238+
async for chunk in self.app.astream({...}, stream_mode="messages"):
239+
topic.publish(chunk)
240+
topic.publish({"done": True})
241+
```
242+
243+
### Retry semantics
244+
245+
Streaming has **at-least-once** delivery per activity attempt. When an activity-wrapped node retries (transient failure, worker crash, etc.), the user function re-runs from scratch and re-publishes its writes — earlier publishes from the failed attempt are not rolled back. Subscribers should be ready to see duplicates and recover idempotently (e.g. dedupe on a sequence id you include in each chunk, or treat the stream as advisory and rely on the workflow's final result for state).
246+
146247
## Tracing
147248

148249
We recommend the [Temporal LangSmith Plugin](https://github.com/temporalio/sdk-python/tree/main/temporalio/contrib/langsmith) to trace your LangGraph Workflows and Activities.

temporalio/contrib/langgraph/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
__all__ = [
2121
"LangGraphPlugin",
22-
"entrypoint",
2322
"cache",
23+
"entrypoint",
2424
"graph",
2525
]

temporalio/contrib/langgraph/_activity.py

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
"""Activity wrappers for executing LangGraph nodes and tasks."""
22

3+
import asyncio
34
from collections.abc import Awaitable
45
from dataclasses import dataclass
6+
from datetime import timedelta
57
from inspect import iscoroutinefunction, signature
68
from typing import Any, Callable
79

@@ -19,6 +21,7 @@
1921
cache_lookup,
2022
cache_put,
2123
)
24+
from temporalio.contrib.workflow_streams import WorkflowStreamClient
2225

2326
# Per-run dedupe so we only warn once when a user passes a Store via
2427
# graph.compile(store=...) / @entrypoint(store=...). Cleared by
@@ -51,28 +54,54 @@ class ActivityOutput:
5154

5255
def wrap_activity(
5356
func: Callable,
57+
*,
58+
streaming_topic: str | None = None,
59+
streaming_batch_interval: timedelta = timedelta(milliseconds=100),
5460
) -> Callable[[ActivityInput], Awaitable[ActivityOutput]]:
5561
"""Wrap a function as a Temporal activity that handles LangGraph config and interrupts."""
56-
# Graph nodes declare `runtime: Runtime[Ctx]` in their signature; tasks
57-
# don't and instead reach for Runtime via get_runtime(). We re-inject the
58-
# reconstructed Runtime only when the user function asks.
5962
accepts_runtime = "runtime" in signature(func).parameters
6063

6164
async def wrapper(input: ActivityInput) -> ActivityOutput:
62-
runtime = set_langgraph_config(input.langgraph_config)
63-
kwargs = dict(input.kwargs)
64-
if accepts_runtime:
65-
kwargs["runtime"] = runtime
66-
try:
67-
if iscoroutinefunction(func):
68-
result = await func(*input.args, **kwargs)
69-
else:
70-
result = func(*input.args, **kwargs)
71-
if isinstance(result, Command):
72-
return ActivityOutput(langgraph_command=result)
73-
return ActivityOutput(result=result)
74-
except GraphInterrupt as e:
75-
return ActivityOutput(langgraph_interrupts=e.args[0])
65+
async def run(stream_writer: Callable[[Any], None] | None) -> ActivityOutput:
66+
# Sync funcs run on a thread (so the loop keeps flushing the
67+
# stream client mid-execution); marshal writer calls back to
68+
# the loop thread because the client's flush event is an
69+
# asyncio.Event and isn't safe to set off-thread.
70+
effective_writer = stream_writer
71+
if not iscoroutinefunction(func) and stream_writer is not None:
72+
loop = asyncio.get_running_loop()
73+
inner_writer = stream_writer
74+
75+
def thread_safe_writer(value: Any) -> None:
76+
loop.call_soon_threadsafe(inner_writer, value)
77+
78+
effective_writer = thread_safe_writer
79+
80+
runtime = set_langgraph_config(
81+
input.langgraph_config, stream_writer=effective_writer
82+
)
83+
kwargs = dict(input.kwargs)
84+
if accepts_runtime:
85+
kwargs["runtime"] = runtime
86+
87+
try:
88+
if iscoroutinefunction(func):
89+
result = await func(*input.args, **kwargs)
90+
else:
91+
result = await asyncio.to_thread(func, *input.args, **kwargs)
92+
if isinstance(result, Command):
93+
return ActivityOutput(langgraph_command=result)
94+
return ActivityOutput(result=result)
95+
except GraphInterrupt as e:
96+
return ActivityOutput(langgraph_interrupts=e.args[0])
97+
98+
if streaming_topic is None:
99+
return await run(stream_writer=None)
100+
async with WorkflowStreamClient.from_within_activity(
101+
batch_interval=streaming_batch_interval,
102+
) as client:
103+
topic = client.topic(streaming_topic)
104+
return await run(stream_writer=topic.publish)
76105

77106
return wrapper
78107

temporalio/contrib/langgraph/_interceptor.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from temporalio import workflow
1313
from temporalio.contrib.langgraph._activity import clear_store_warning
14+
from temporalio.contrib.workflow_streams._stream import _PUBLISH_SIGNAL
1415
from temporalio.worker import (
1516
ExecuteWorkflowInput,
1617
Interceptor,
@@ -30,17 +31,20 @@ def __init__(
3031
self,
3132
graphs: dict[str, StateGraph[Any, Any, Any, Any]],
3233
entrypoints: dict[str, Pregel[Any, Any, Any, Any]],
34+
streaming_topic: str | None = None,
3335
) -> None:
3436
"""Initialize with the graphs and entrypoints to scope to each workflow run."""
3537
self._graphs = graphs
3638
self._entrypoints = entrypoints
39+
self._streaming_topic = streaming_topic
3740

3841
def workflow_interceptor_class(
3942
self, input: WorkflowInterceptorClassInput
4043
) -> type[WorkflowInboundInterceptor]:
4144
"""Return the inbound interceptor class used to scope graphs per run."""
4245
graphs = self._graphs
4346
entrypoints = self._entrypoints
47+
streaming_topic = self._streaming_topic
4448

4549
class Inbound(WorkflowInboundInterceptor):
4650
def init(self, outbound: WorkflowOutboundInterceptor) -> None:
@@ -50,6 +54,18 @@ def init(self, outbound: WorkflowOutboundInterceptor) -> None:
5054
super().init(outbound)
5155

5256
async def execute_workflow(self, input: ExecuteWorkflowInput) -> Any:
57+
if (
58+
streaming_topic is not None
59+
and workflow.get_signal_handler(_PUBLISH_SIGNAL) is None
60+
):
61+
raise RuntimeError(
62+
f"LangGraphPlugin was configured with "
63+
f"streaming_topic={streaming_topic!r}, but workflow "
64+
f"{workflow.info().workflow_type!r} did not register a "
65+
f"WorkflowStream. Construct WorkflowStream() in the "
66+
f"workflow's @workflow.init (i.e. __init__) method so "
67+
f"streaming activities can publish to it."
68+
)
5369
try:
5470
return await self.next.execute_workflow(input)
5571
finally:

0 commit comments

Comments
 (0)