-
Notifications
You must be signed in to change notification settings - Fork 843
feat: pass invocation_state to edge condition calls #2305
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,11 +16,12 @@ | |
|
|
||
| import asyncio | ||
| import copy | ||
| import inspect | ||
| import logging | ||
| import time | ||
| from collections.abc import AsyncIterator, Callable, Mapping | ||
| from dataclasses import dataclass, field | ||
| from typing import Any, cast | ||
| from typing import Any, Protocol, TypeGuard, cast, runtime_checkable | ||
|
|
||
| from opentelemetry import trace as trace_api | ||
|
|
||
|
|
@@ -62,6 +63,39 @@ | |
| _DEFAULT_GRAPH_ID = "default_graph" | ||
|
|
||
|
|
||
| @runtime_checkable | ||
| class EdgeConditionWithContext(Protocol): | ||
| """Protocol for edge conditions that receive invocation_state. | ||
|
|
||
| This allows conditions to make routing decisions based on runtime context | ||
| passed during graph invocation, such as feature flags, user roles, or | ||
| environment-specific configuration. | ||
|
|
||
| Designed with **kwargs for future extensibility without breaking changes. | ||
| """ | ||
|
|
||
| def __call__(self, state: "GraphState", *, invocation_state: dict[str, Any], **kwargs: Any) -> bool: | ||
| """Evaluate whether the edge should be traversed.""" | ||
| ... | ||
|
|
||
|
|
||
| LegacyEdgeCondition = Callable[["GraphState"], bool] | ||
| EdgeCondition = LegacyEdgeCondition | EdgeConditionWithContext | ||
|
|
||
|
|
||
| def _is_context_condition(condition: EdgeCondition) -> TypeGuard[EdgeConditionWithContext]: | ||
| """Check if a condition function accepts invocation_state parameter. | ||
|
|
||
| Uses inspect.signature() for reliable detection, returning a TypeGuard | ||
| so mypy can narrow the type at call sites. | ||
| """ | ||
| try: | ||
| sig = inspect.signature(condition) | ||
| return "invocation_state" in sig.parameters | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Issue: Suggestion: Cache the result per condition function using _condition_type_cache: dict[int, bool] = {}
def _is_context_condition(condition: EdgeCondition) -> TypeGuard[EdgeConditionWithContext]:
cond_id = id(condition)
if cond_id not in _condition_type_cache:
try:
sig = inspect.signature(condition)
_condition_type_cache[cond_id] = "invocation_state" in sig.parameters
except (ValueError, TypeError):
_condition_type_cache[cond_id] = False
return _condition_type_cache[cond_id]The issue spec also mentions this: "Consider caching signature inspection results for performance if needed." |
||
| except (ValueError, TypeError): | ||
| return False | ||
|
|
||
|
|
||
| @dataclass | ||
| class GraphState: | ||
| """Graph execution state. | ||
|
|
@@ -147,17 +181,28 @@ class GraphEdge: | |
|
|
||
| from_node: "GraphNode" | ||
| to_node: "GraphNode" | ||
| condition: Callable[[GraphState], bool] | None = None | ||
| condition: EdgeCondition | None = None | ||
|
|
||
| def __hash__(self) -> int: | ||
| """Return hash for GraphEdge based on from_node and to_node.""" | ||
| return hash((self.from_node.node_id, self.to_node.node_id)) | ||
|
|
||
| def should_traverse(self, state: GraphState) -> bool: | ||
| """Check if this edge should be traversed based on condition.""" | ||
| if self.condition is None: | ||
| def should_traverse(self, state: GraphState, *, invocation_state: dict[str, Any] | None = None) -> bool: | ||
| """Check if this edge should be traversed based on condition. | ||
|
|
||
| Args: | ||
| state: The current graph execution state. | ||
| invocation_state: Runtime context passed during graph invocation. | ||
| New-style conditions (EdgeConditionWithContext) receive this parameter. | ||
| Legacy conditions (Callable[[GraphState], bool]) are called with state only. | ||
| """ | ||
| condition = self.condition | ||
| if condition is None: | ||
| return True | ||
| return self.condition(state) | ||
| if _is_context_condition(condition): | ||
| return condition(state, invocation_state=invocation_state or {}) | ||
| legacy_condition = cast(LegacyEdgeCondition, condition) | ||
| return legacy_condition(state) | ||
|
|
||
|
|
||
| @dataclass | ||
|
|
@@ -276,9 +321,14 @@ def add_edge( | |
| self, | ||
| from_node: str | GraphNode, | ||
| to_node: str | GraphNode, | ||
| condition: Callable[[GraphState], bool] | None = None, | ||
| condition: EdgeCondition | None = None, | ||
| ) -> GraphEdge: | ||
| """Add an edge between two nodes with optional condition function that receives full GraphState.""" | ||
| """Add an edge between two nodes with optional condition function. | ||
|
|
||
| The condition can be either: | ||
| - A legacy callable: Callable[[GraphState], bool] - receives only graph state | ||
| - A new-style callable: EdgeConditionWithContext - receives graph state and invocation_state | ||
| """ | ||
|
|
||
| def resolve_node(node: str | GraphNode, node_type: str) -> GraphNode: | ||
| if isinstance(node, str): | ||
|
|
@@ -491,6 +541,7 @@ def __init__( | |
|
|
||
| self._resume_next_nodes: list[GraphNode] = [] | ||
| self._resume_from_session = False | ||
| self._current_invocation_state: dict[str, Any] = {} | ||
| self.id = id | ||
|
|
||
| run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self))) | ||
|
|
@@ -569,6 +620,8 @@ async def stream_async( | |
| if invocation_state is None: | ||
| invocation_state = {} | ||
|
|
||
| self._current_invocation_state = invocation_state | ||
|
|
||
| await self.hooks.invoke_callbacks_async(BeforeMultiAgentInvocationEvent(self, invocation_state)) | ||
|
|
||
| logger.debug("task=<%s> | starting graph execution", task) | ||
|
|
@@ -889,7 +942,7 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[ | |
| # Check if at least one incoming edge condition is satisfied | ||
| for edge in incoming_edges: | ||
| if edge.from_node in completed_batch: | ||
| if edge.should_traverse(self.state): | ||
| if edge.should_traverse(self.state, invocation_state=self._current_invocation_state): | ||
| logger.debug( | ||
| "from=<%s>, to=<%s> | edge ready via satisfied condition", edge.from_node.node_id, node.node_id | ||
| ) | ||
|
|
@@ -1125,7 +1178,7 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]: | |
| and edge.from_node in self.state.completed_nodes | ||
| and edge.from_node.node_id in self.state.results | ||
| ): | ||
| if edge.should_traverse(self.state): | ||
| if edge.should_traverse(self.state, invocation_state=self._current_invocation_state): | ||
| dependency_results[edge.from_node.node_id] = self.state.results[edge.from_node.node_id] | ||
|
|
||
| if not dependency_results: | ||
|
|
@@ -1201,6 +1254,7 @@ def serialize_state(self) -> dict[str, Any]: | |
| "next_nodes_to_execute": next_nodes, | ||
| "current_task": encode_bytes_values(self.state.task), | ||
| "execution_order": [n.node_id for n in self.state.execution_order], | ||
| "invocation_state": self._current_invocation_state, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Issue: Suggestion: Consider either:
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Addressed in |
||
| "_internal_state": { | ||
| "interrupt_state": self._interrupt_state.to_dict(), | ||
| }, | ||
|
|
@@ -1223,6 +1277,8 @@ def deserialize_state(self, payload: dict[str, Any]) -> None: | |
| internal_state = payload["_internal_state"] | ||
| self._interrupt_state = _InterruptState.from_dict(internal_state["interrupt_state"]) | ||
|
|
||
| self._current_invocation_state = payload.get("invocation_state", {}) | ||
|
|
||
| if not payload.get("next_nodes_to_execute"): | ||
| # Reset all nodes | ||
| for node in self.nodes.values(): | ||
|
|
@@ -1246,11 +1302,34 @@ def _compute_ready_nodes_for_resume(self) -> list[GraphNode]: | |
| incoming = [e for e in self.edges if e.to_node is node] | ||
| if not incoming: | ||
| ready_nodes.append(node) | ||
| elif all(e.from_node in completed_nodes and e.should_traverse(self.state) for e in incoming): | ||
| elif self._is_node_ready_for_resume(node, incoming, completed_nodes): | ||
| ready_nodes.append(node) | ||
|
|
||
| return ready_nodes | ||
|
|
||
| def _is_node_ready_for_resume( | ||
| self, | ||
| node: GraphNode, | ||
| incoming: list[GraphEdge], | ||
| completed_nodes: set[GraphNode], | ||
| ) -> bool: | ||
| """Check if a node is ready for resume, accounting for conditional edges. | ||
|
|
||
| A node is ready if all TRAVERSABLE incoming edges have their source completed. | ||
| Edges whose condition evaluates to False are excluded from the check — they | ||
| represent paths that were intentionally skipped. | ||
| """ | ||
| traversable_edges = [ | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Issue: In the Also, is calling traversable_edges = [
e for e in incoming
if e.should_traverse(self.state, invocation_state=self._current_invocation_state)
]Suggestion: The explicit
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Addressed in |
||
| e | ||
| for e in incoming | ||
| if e.condition is None or e.should_traverse(self.state, invocation_state=self._current_invocation_state) | ||
| ] | ||
|
|
||
| if not traversable_edges: | ||
| return False | ||
|
|
||
| return all(e.from_node in completed_nodes for e in traversable_edges) | ||
|
|
||
| def _from_dict(self, payload: dict[str, Any]) -> None: | ||
| self.state.status = Status(payload["status"]) | ||
| # Hydrate completed nodes & results | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Issue: The
@runtime_checkabledecorator is present but not used for dispatch —_is_context_conditionusesinspect.signature()instead ofisinstance(). This is the correct choice (sinceisinstancecan't distinguish the two signatures structurally), but the@runtime_checkabledecorator may mislead users into thinking they can useisinstance(cond, EdgeConditionWithContext)for the same purpose.Suggestion: Consider adding a brief comment noting that
@runtime_checkableis for type-checking ergonomics / documentation purposes only, not for runtime dispatch, sinceisinstancewill match both legacy and new-style conditions.