Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/strands/multiagent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
"""

from .base import MultiAgentBase, MultiAgentResult, Status
from .graph import GraphBuilder, GraphResult
from .graph import EdgeCondition, EdgeConditionWithContext, GraphBuilder, GraphResult
from .swarm import Swarm, SwarmResult

__all__ = [
"EdgeCondition",
"EdgeConditionWithContext",
"GraphBuilder",
"GraphResult",
"MultiAgentBase",
Expand Down
101 changes: 90 additions & 11 deletions src/strands/multiagent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@

import asyncio
import copy
import inspect
import logging
import time
from collections.abc import AsyncIterator, Callable, Mapping
from dataclasses import dataclass, field
from typing import Any, cast
from typing import Any, Protocol, TypeGuard, cast, runtime_checkable

from opentelemetry import trace as trace_api

Expand Down Expand Up @@ -62,6 +63,39 @@
_DEFAULT_GRAPH_ID = "default_graph"


@runtime_checkable
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue: The @runtime_checkable decorator is present but not used for dispatch — _is_context_condition uses inspect.signature() instead of isinstance(). This is the correct choice (since isinstance can't distinguish the two signatures structurally), but the @runtime_checkable decorator may mislead users into thinking they can use isinstance(cond, EdgeConditionWithContext) for the same purpose.

Suggestion: Consider adding a brief comment noting that @runtime_checkable is for type-checking ergonomics / documentation purposes only, not for runtime dispatch, since isinstance will match both legacy and new-style conditions.

class EdgeConditionWithContext(Protocol):
    """Protocol for edge conditions that receive invocation_state.

    This allows conditions to make routing decisions based on runtime context
    passed during graph invocation, such as feature flags, user roles, or
    environment-specific configuration.

    Designed with **kwargs for future extensibility without breaking changes.

    Note:
        ``isinstance(cond, EdgeConditionWithContext)`` cannot tell context-aware
        conditions apart from legacy ``Callable[[GraphState], bool]`` ones: a
        runtime protocol check only verifies that ``__call__`` exists, which is
        true of every callable. Dispatch between the two styles is done by
        ``_is_context_condition``, which inspects the callable's signature for
        an ``invocation_state`` parameter.
    """

    def __call__(self, state: "GraphState", *, invocation_state: dict[str, Any], **kwargs: Any) -> bool:
        """Evaluate whether the edge should be traversed.

        Args:
            state: The current graph execution state.
            invocation_state: Runtime context passed during graph invocation.
                Keyword-only.
            **kwargs: Reserved so further keyword arguments can be added later
                without breaking existing conditions.

        Returns:
            True if the edge should be traversed, False otherwise.
        """
        ...


LegacyEdgeCondition = Callable[["GraphState"], bool]
EdgeCondition = LegacyEdgeCondition | EdgeConditionWithContext


def _is_context_condition(condition: EdgeCondition) -> TypeGuard[EdgeConditionWithContext]:
"""Check if a condition function accepts invocation_state parameter.

Uses inspect.signature() for reliable detection, returning a TypeGuard
so mypy can narrow the type at call sites.
"""
try:
sig = inspect.signature(condition)
return "invocation_state" in sig.parameters
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue: inspect.signature() is called on every should_traverse() invocation with no caching. In graphs with many edges or cyclic execution, this could become a performance bottleneck since signature introspection is relatively expensive.

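Suggestion: Cache the result per condition function, e.g. using functools.lru_cache or a mapping keyed by the condition (a plain dict keyed by id(condition) is the simplest form, but ids can be reused after a callable is garbage-collected, so holding the callable through a weak reference is more robust):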

_condition_type_cache: dict[int, bool] = {}

def _is_context_condition(condition: EdgeCondition) -> TypeGuard[EdgeConditionWithContext]:
    cond_id = id(condition)
    if cond_id not in _condition_type_cache:
        try:
            sig = inspect.signature(condition)
            _condition_type_cache[cond_id] = "invocation_state" in sig.parameters
        except (ValueError, TypeError):
            _condition_type_cache[cond_id] = False
    return _condition_type_cache[cond_id]

The issue spec also mentions this: "Consider caching signature inspection results for performance if needed."

except (ValueError, TypeError):
return False


@dataclass
class GraphState:
"""Graph execution state.
Expand Down Expand Up @@ -147,17 +181,28 @@ class GraphEdge:

from_node: "GraphNode"
to_node: "GraphNode"
condition: Callable[[GraphState], bool] | None = None
condition: EdgeCondition | None = None

def __hash__(self) -> int:
"""Return hash for GraphEdge based on from_node and to_node."""
return hash((self.from_node.node_id, self.to_node.node_id))

def should_traverse(self, state: GraphState) -> bool:
"""Check if this edge should be traversed based on condition."""
if self.condition is None:
def should_traverse(self, state: GraphState, *, invocation_state: dict[str, Any] | None = None) -> bool:
"""Check if this edge should be traversed based on condition.

Args:
state: The current graph execution state.
invocation_state: Runtime context passed during graph invocation.
New-style conditions (EdgeConditionWithContext) receive this parameter.
Legacy conditions (Callable[[GraphState], bool]) are called with state only.
"""
condition = self.condition
if condition is None:
return True
return self.condition(state)
if _is_context_condition(condition):
return condition(state, invocation_state=invocation_state or {})
legacy_condition = cast(LegacyEdgeCondition, condition)
return legacy_condition(state)


@dataclass
Expand Down Expand Up @@ -276,9 +321,14 @@ def add_edge(
self,
from_node: str | GraphNode,
to_node: str | GraphNode,
condition: Callable[[GraphState], bool] | None = None,
condition: EdgeCondition | None = None,
) -> GraphEdge:
"""Add an edge between two nodes with optional condition function that receives full GraphState."""
"""Add an edge between two nodes with optional condition function.

The condition can be either:
- A legacy callable: Callable[[GraphState], bool] - receives only graph state
- A new-style callable: EdgeConditionWithContext - receives graph state and invocation_state
"""

def resolve_node(node: str | GraphNode, node_type: str) -> GraphNode:
if isinstance(node, str):
Expand Down Expand Up @@ -491,6 +541,7 @@ def __init__(

self._resume_next_nodes: list[GraphNode] = []
self._resume_from_session = False
self._current_invocation_state: dict[str, Any] = {}
self.id = id

run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self)))
Expand Down Expand Up @@ -569,6 +620,8 @@ async def stream_async(
if invocation_state is None:
invocation_state = {}

self._current_invocation_state = invocation_state

await self.hooks.invoke_callbacks_async(BeforeMultiAgentInvocationEvent(self, invocation_state))

logger.debug("task=<%s> | starting graph execution", task)
Expand Down Expand Up @@ -889,7 +942,7 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[
# Check if at least one incoming edge condition is satisfied
for edge in incoming_edges:
if edge.from_node in completed_batch:
if edge.should_traverse(self.state):
if edge.should_traverse(self.state, invocation_state=self._current_invocation_state):
logger.debug(
"from=<%s>, to=<%s> | edge ready via satisfied condition", edge.from_node.node_id, node.node_id
)
Expand Down Expand Up @@ -1125,7 +1178,7 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]:
and edge.from_node in self.state.completed_nodes
and edge.from_node.node_id in self.state.results
):
if edge.should_traverse(self.state):
if edge.should_traverse(self.state, invocation_state=self._current_invocation_state):
dependency_results[edge.from_node.node_id] = self.state.results[edge.from_node.node_id]

if not dependency_results:
Expand Down Expand Up @@ -1201,6 +1254,7 @@ def serialize_state(self) -> dict[str, Any]:
"next_nodes_to_execute": next_nodes,
"current_task": encode_bytes_values(self.state.task),
"execution_order": [n.node_id for n in self.state.execution_order],
"invocation_state": self._current_invocation_state,
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue: invocation_state is serialized directly into the session payload without validation. If a user passes non-JSON-serializable objects (e.g., class instances, functions) in invocation_state, this will fail silently or raise an unclear error during serialization.

Suggestion: Consider either:

  1. Documenting that invocation_state values must be JSON-serializable, or
  2. Adding a validation/guard in serialize_state() that provides a clear error message if serialization fails due to non-serializable invocation_state values.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in 7cc0d04 — added _validate_invocation_state() that calls json.dumps() and raises a clear TypeError if the value isn't serializable. It's gated on session_manager is not None (serialization only matters when sessions persist), and also validated symmetrically on the deserialization path in deserialize_state.

"_internal_state": {
"interrupt_state": self._interrupt_state.to_dict(),
},
Expand All @@ -1223,6 +1277,8 @@ def deserialize_state(self, payload: dict[str, Any]) -> None:
internal_state = payload["_internal_state"]
self._interrupt_state = _InterruptState.from_dict(internal_state["interrupt_state"])

self._current_invocation_state = payload.get("invocation_state", {})

if not payload.get("next_nodes_to_execute"):
# Reset all nodes
for node in self.nodes.values():
Expand All @@ -1246,11 +1302,34 @@ def _compute_ready_nodes_for_resume(self) -> list[GraphNode]:
incoming = [e for e in self.edges if e.to_node is node]
if not incoming:
ready_nodes.append(node)
elif all(e.from_node in completed_nodes and e.should_traverse(self.state) for e in incoming):
elif self._is_node_ready_for_resume(node, incoming, completed_nodes):
ready_nodes.append(node)

return ready_nodes

def _is_node_ready_for_resume(
self,
node: GraphNode,
incoming: list[GraphEdge],
completed_nodes: set[GraphNode],
) -> bool:
"""Check if a node is ready for resume, accounting for conditional edges.

A node is ready if all TRAVERSABLE incoming edges have their source completed.
Edges whose condition evaluates to False are excluded from the check — they
represent paths that were intentionally skipped.
"""
traversable_edges = [
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue: In the _is_node_ready_for_resume method, the logic evaluates all edge conditions to determine traversability. However, invocation_state is restored from the serialized payload while the edge conditions themselves are defined in code, so if the graph code changes between serialize and deserialize (e.g., condition logic updated), the behavior on resume could be subtly different. This is inherent to the design but worth noting.

Also, is the explicit e.condition is None check here redundant with calling e.should_traverse()? should_traverse already returns True when condition is None, so the filter could be simplified to just e.should_traverse(...):

traversable_edges = [
    e for e in incoming
    if e.should_traverse(self.state, invocation_state=self._current_invocation_state)
]

Suggestion: The explicit e.condition is None check is likely a micro-optimization to avoid the inspect.signature() call for unconditional edges. If caching is added to _is_context_condition, this optimization becomes unnecessary and the code can be simplified.

Copy link
Copy Markdown
Author

@yananym yananym May 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in 7cc0d04. You're right that with the WeakKeyDictionary cache now in place, the e.condition is None short-circuit is no longer necessary for performance. I kept it because it still avoids the method call + cache lookup entirely for unconditional edges, and makes the "unconditional edges always traverse" semantics explicit at the call site. Added an inline comment explaining the intent. Happy to simplify if you'd prefer the terser form.
On the serialize/deserialize behavior divergence — added a docstring to _compute_ready_nodes_for_resume acknowledging this explicitly: re-evaluation is intentional since invocation_state may differ between invocations, and condition logic changes taking effect on resume is consistent with the graph being defined in code rather than serialized.

e
for e in incoming
if e.condition is None or e.should_traverse(self.state, invocation_state=self._current_invocation_state)
]

if not traversable_edges:
return False

return all(e.from_node in completed_nodes for e in traversable_edges)

def _from_dict(self, payload: dict[str, Any]) -> None:
self.state.status = Status(payload["status"])
# Hydrate completed nodes & results
Expand Down
Loading