Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/strands/multiagent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
"""

from .base import MultiAgentBase, MultiAgentResult, Status
from .graph import GraphBuilder, GraphResult
from .graph import EdgeCondition, EdgeConditionWithContext, GraphBuilder, GraphResult
from .swarm import Swarm, SwarmResult

__all__ = [
"EdgeCondition",
"EdgeConditionWithContext",
"GraphBuilder",
"GraphResult",
"MultiAgentBase",
Expand Down
145 changes: 134 additions & 11 deletions src/strands/multiagent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@

import asyncio
import copy
import inspect
import json
import logging
import time
import weakref
from collections.abc import AsyncIterator, Callable, Mapping
from dataclasses import dataclass, field
from typing import Any, cast
from typing import Any, Protocol, TypeGuard, cast

from opentelemetry import trace as trace_api

Expand Down Expand Up @@ -62,6 +65,57 @@
_DEFAULT_GRAPH_ID = "default_graph"


class EdgeConditionWithContext(Protocol):
    """Structural type for edge conditions that also receive ``invocation_state``.

    A condition of this shape can base its routing decision on runtime context
    supplied when the graph is invoked (for example feature flags, user roles,
    or environment-specific configuration) in addition to the graph state.

    The signature ends in ``**kwargs`` so further keyword arguments can be
    introduced later without breaking existing conditions.

    Note: this protocol is deliberately not ``@runtime_checkable``. ``isinstance()``
    cannot tell callable signatures apart structurally, so dispatch goes through
    ``_is_context_condition()`` instead.
    """

    def __call__(
        self,
        state: "GraphState",
        *,
        invocation_state: dict[str, Any],
        **kwargs: Any,
    ) -> bool:
        """Return True when the edge should be traversed."""
        ...


# Original condition shape: a callable that receives only the graph state.
LegacyEdgeCondition = Callable[["GraphState"], bool]
# Any supported edge condition: the legacy shape or the context-aware protocol.
EdgeCondition = LegacyEdgeCondition | EdgeConditionWithContext

# Memoizes the result of _is_context_condition() per condition callable.
# Keys are held weakly, so an entry disappears once its callable is garbage
# collected. Ephemeral callables (e.g. lambdas recreated per-call) simply miss
# the cache — this is benign; the fallback path is a single
# inspect.signature() call.
# NOTE(review): the original note relied on CPython's GIL making concurrent
# reads/writes safe. WeakKeyDictionary is not a plain dict, so confirm that
# guarantee before depending on it from multiple threads. A racing miss should
# only recompute the same boolean for the same callable.
_context_condition_cache: weakref.WeakKeyDictionary[EdgeCondition, bool] = weakref.WeakKeyDictionary()


def _is_context_condition(condition: EdgeCondition) -> TypeGuard[EdgeConditionWithContext]:
    """Report whether ``condition`` should be called with ``invocation_state``.

    The decision is made from ``inspect.signature()``: the condition counts as
    context-aware only when it declares a parameter named ``invocation_state``
    that can actually be bound by keyword. A positional-only parameter or a
    ``*invocation_state`` catch-all shares the name but would raise
    ``TypeError`` when the caller passes ``invocation_state=...``, so those are
    treated as legacy conditions.

    The return type is a ``TypeGuard`` so type checkers can narrow the
    condition at call sites. Results are memoized in ``_context_condition_cache``,
    which holds its keys weakly.

    Args:
        condition: The edge condition callable to classify.

    Returns:
        True if the condition accepts ``invocation_state`` as a keyword argument,
        False otherwise (including when its signature cannot be introspected).
    """
    try:
        return _context_condition_cache[condition]
    except (KeyError, TypeError):
        # KeyError: not cached yet. TypeError: presumably the callable cannot be
        # hashed or weakly referenced, in which case it is never cached.
        pass

    try:
        parameter = inspect.signature(condition).parameters.get("invocation_state")
    except (ValueError, TypeError):
        # No usable signature: fall back to the legacy calling convention.
        parameter = None

    # Matching on the name alone is not enough; the parameter must be bindable
    # by keyword, because that is how context-aware conditions are invoked.
    result = parameter is not None and parameter.kind not in (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.VAR_POSITIONAL,
    )

    try:
        _context_condition_cache[condition] = result
    except TypeError:
        # Best-effort caching only; uncacheable callables are re-inspected.
        pass
    return result


@dataclass
class GraphState:
"""Graph execution state.
Expand Down Expand Up @@ -147,17 +201,28 @@ class GraphEdge:

from_node: "GraphNode"
to_node: "GraphNode"
condition: Callable[[GraphState], bool] | None = None
condition: EdgeCondition | None = None

def __hash__(self) -> int:
"""Return hash for GraphEdge based on from_node and to_node."""
return hash((self.from_node.node_id, self.to_node.node_id))

def should_traverse(self, state: GraphState) -> bool:
"""Check if this edge should be traversed based on condition."""
if self.condition is None:
def should_traverse(self, state: GraphState, *, invocation_state: dict[str, Any] | None = None) -> bool:
"""Check if this edge should be traversed based on condition.

Args:
state: The current graph execution state.
invocation_state: Runtime context passed during graph invocation.
New-style conditions (EdgeConditionWithContext) receive this parameter.
Legacy conditions (Callable[[GraphState], bool]) are called with state only.
"""
condition = self.condition
if condition is None:
return True
return self.condition(state)
if _is_context_condition(condition):
return condition(state, invocation_state=invocation_state or {})
legacy_condition = cast(LegacyEdgeCondition, condition)
return legacy_condition(state)


@dataclass
Expand Down Expand Up @@ -276,9 +341,14 @@ def add_edge(
self,
from_node: str | GraphNode,
to_node: str | GraphNode,
condition: Callable[[GraphState], bool] | None = None,
condition: EdgeCondition | None = None,
) -> GraphEdge:
"""Add an edge between two nodes with optional condition function that receives full GraphState."""
"""Add an edge between two nodes with optional condition function.

The condition can be either:
- A legacy callable: Callable[[GraphState], bool] - receives only graph state
- A new-style callable: EdgeConditionWithContext - receives graph state and invocation_state
"""

def resolve_node(node: str | GraphNode, node_type: str) -> GraphNode:
if isinstance(node, str):
Expand Down Expand Up @@ -491,6 +561,7 @@ def __init__(

self._resume_next_nodes: list[GraphNode] = []
self._resume_from_session = False
self._current_invocation_state: dict[str, Any] = {}
self.id = id

run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self)))
Expand Down Expand Up @@ -569,6 +640,10 @@ async def stream_async(
if invocation_state is None:
invocation_state = {}

if self.session_manager is not None:
self._validate_invocation_state(invocation_state)
self._current_invocation_state = invocation_state

await self.hooks.invoke_callbacks_async(BeforeMultiAgentInvocationEvent(self, invocation_state))

logger.debug("task=<%s> | starting graph execution", task)
Expand Down Expand Up @@ -889,7 +964,7 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[
# Check if at least one incoming edge condition is satisfied
for edge in incoming_edges:
if edge.from_node in completed_batch:
if edge.should_traverse(self.state):
if edge.should_traverse(self.state, invocation_state=self._current_invocation_state):
logger.debug(
"from=<%s>, to=<%s> | edge ready via satisfied condition", edge.from_node.node_id, node.node_id
)
Expand Down Expand Up @@ -1125,7 +1200,7 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]:
and edge.from_node in self.state.completed_nodes
and edge.from_node.node_id in self.state.results
):
if edge.should_traverse(self.state):
if edge.should_traverse(self.state, invocation_state=self._current_invocation_state):
dependency_results[edge.from_node.node_id] = self.state.results[edge.from_node.node_id]

if not dependency_results:
Expand Down Expand Up @@ -1186,6 +1261,20 @@ def _build_result(self, interrupts: list[Interrupt]) -> GraphResult:
interrupts=interrupts,
)

@staticmethod
def _validate_invocation_state(invocation_state: dict[str, Any]) -> None:
"""Validate that invocation_state is JSON-serializable.

Raises:
TypeError: If invocation_state contains non-JSON-serializable values.
"""
try:
json.dumps(invocation_state)
except (TypeError, ValueError) as e:
raise TypeError(
f"invocation_state must be JSON-serializable for session persistence: {e}"
) from e

def serialize_state(self) -> dict[str, Any]:
"""Serialize the current graph state to a dictionary."""
compute_nodes = self._compute_ready_nodes_for_resume()
Expand All @@ -1201,6 +1290,7 @@ def serialize_state(self) -> dict[str, Any]:
"next_nodes_to_execute": next_nodes,
"current_task": encode_bytes_values(self.state.task),
"execution_order": [n.node_id for n in self.state.execution_order],
"invocation_state": self._current_invocation_state,
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue: invocation_state is serialized directly into the session payload without validation. If a user passes non-JSON-serializable objects (e.g., class instances, functions) in invocation_state, this will fail silently or raise an unclear error during serialization.

Suggestion: Consider either:

  1. Documenting that invocation_state values must be JSON-serializable, or
  2. Adding a validation/guard in serialize_state() that provides a clear error message if serialization fails due to non-serializable invocation_state values.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in 7cc0d04 — added _validate_invocation_state() that calls json.dumps() and raises a clear TypeError if the value isn't serializable. It's gated on session_manager is not None (serialization only matters when sessions persist), and also validated symmetrically on the deserialization path in deserialize_state.

"_internal_state": {
"interrupt_state": self._interrupt_state.to_dict(),
},
Expand All @@ -1223,6 +1313,10 @@ def deserialize_state(self, payload: dict[str, Any]) -> None:
internal_state = payload["_internal_state"]
self._interrupt_state = _InterruptState.from_dict(internal_state["interrupt_state"])

invocation_state = payload.get("invocation_state", {})
self._validate_invocation_state(invocation_state)
self._current_invocation_state = invocation_state

if not payload.get("next_nodes_to_execute"):
# Reset all nodes
for node in self.nodes.values():
Expand All @@ -1246,11 +1340,40 @@ def _compute_ready_nodes_for_resume(self) -> list[GraphNode]:
incoming = [e for e in self.edges if e.to_node is node]
if not incoming:
ready_nodes.append(node)
elif all(e.from_node in completed_nodes and e.should_traverse(self.state) for e in incoming):
elif self._is_node_ready_for_resume(node, incoming, completed_nodes):
ready_nodes.append(node)

return ready_nodes

def _is_node_ready_for_resume(
self,
node: GraphNode,
incoming: list[GraphEdge],
completed_nodes: set[GraphNode],
) -> bool:
"""Check if a node is ready for resume, accounting for conditional edges.

A node is ready if all TRAVERSABLE incoming edges have their source completed.
Edges whose condition evaluates to False are excluded from the check — they
represent paths that were intentionally skipped.

Re-evaluates conditions (rather than caching traversal results) intentionally:
invocation_state may change between invocations, so conditions must reflect
current runtime context. This means condition logic changes between serialize
and resume will also take effect — consistent with the graph being defined in code.
"""
traversable_edges = [
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue: In the _is_node_ready_for_resume method, the logic evaluates all edge conditions to determine traversability. However, since invocation_state is restored from the serialized payload, but the edge conditions themselves are defined in code — if the graph code changes between serialize and deserialize (e.g., condition logic updated), the behavior could be subtly different on resume. This is inherent to the design but worth noting.

Also, is calling e.should_traverse() here redundant with the e.condition is None check? should_traverse already returns True when condition is None, so the filter could be simplified to just e.should_traverse(...):

traversable_edges = [
    e for e in incoming
    if e.should_traverse(self.state, invocation_state=self._current_invocation_state)
]

Suggestion: The explicit e.condition is None check is likely a micro-optimization to avoid the inspect.signature() call for unconditional edges. If caching is added to _is_context_condition, this optimization becomes unnecessary and the code can be simplified.

Copy link
Copy Markdown
Author

@yananym yananym May 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in 7cc0d04. You're right that with the WeakKeyDictionary cache now in place, the e.condition is None short-circuit is no longer necessary for performance. I kept it because it still avoids the method call + cache lookup entirely for unconditional edges, and makes the "unconditional edges always traverse" semantics explicit at the call site. Added an inline comment explaining the intent. Happy to simplify if you'd prefer the terser form.
On the serialize/deserialize behavior divergence — added a docstring to _is_node_ready_for_resume acknowledging this explicitly: re-evaluation is intentional since invocation_state may differ between invocations, and condition logic changes taking effect on resume is consistent with the graph being defined in code rather than serialized.

e
for e in incoming
# Short-circuit: skip signature inspection + cache lookup for unconditional edges.
if e.condition is None or e.should_traverse(self.state, invocation_state=self._current_invocation_state)
]

if not traversable_edges:
return False

return all(e.from_node in completed_nodes for e in traversable_edges)

def _from_dict(self, payload: dict[str, Any]) -> None:
self.state.status = Status(payload["status"])
# Hydrate completed nodes & results
Expand Down
Loading