Skip to content

Commit e87f80b

Browse files
committed
Removed utils related to AgentBreakpoint and ToolBreakpoint
1 parent 2515aaf commit e87f80b

7 files changed

Lines changed: 32 additions & 560 deletions

File tree

haystack/core/errors.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from typing import Any
66

7-
from haystack.dataclasses.breakpoints import AgentBreakpoint, Breakpoint, PipelineSnapshot, ToolBreakpoint
7+
from haystack.dataclasses.breakpoints import Breakpoint, PipelineSnapshot
88

99

1010
class PipelineError(Exception):
@@ -115,7 +115,7 @@ def __init__(
115115
pipeline_snapshot: PipelineSnapshot | None = None,
116116
pipeline_snapshot_file_path: str | None = None,
117117
*,
118-
break_point: AgentBreakpoint | Breakpoint | ToolBreakpoint | None = None,
118+
break_point: Breakpoint | None = None,
119119
) -> None:
120120
super().__init__(message)
121121
self.component = component
@@ -127,7 +127,7 @@ def __init__(
127127
raise ValueError("Either pipeline_snapshot or break_point must be provided.")
128128

129129
@classmethod
130-
def from_triggered_breakpoint(cls, break_point: Breakpoint | ToolBreakpoint) -> "BreakpointException":
130+
def from_triggered_breakpoint(cls, break_point: Breakpoint) -> "BreakpointException":
131131
"""
132132
Create a BreakpointException from a triggered breakpoint.
133133
"""
@@ -137,37 +137,25 @@ def from_triggered_breakpoint(cls, break_point: Breakpoint | ToolBreakpoint) ->
137137
@property
138138
def inputs(self) -> dict[str, Any] | None:
139139
"""
140-
Returns the inputs of the pipeline or agent at the breakpoint.
141-
142-
If an AgentBreakpoint caused this exception, returns the inputs of the agent's internal components.
143-
Otherwise, returns the current inputs of the pipeline.
140+
Returns the current inputs of the pipeline at the breakpoint.
144141
"""
145142
if not self.pipeline_snapshot:
146143
return None
147-
148-
if self.pipeline_snapshot.agent_snapshot:
149-
return self.pipeline_snapshot.agent_snapshot.component_inputs
150144
return self.pipeline_snapshot.pipeline_state.inputs
151145

152146
@property
153147
def results(self) -> dict[str, Any] | None:
154148
"""
155-
Returns the results of the pipeline or agent at the breakpoint.
156-
157-
If an AgentBreakpoint caused this exception, returns the current results of the agent.
158-
Otherwise, returns the current outputs of the pipeline.
149+
Returns the current outputs of the pipeline at the breakpoint.
159150
"""
160151
if not self.pipeline_snapshot:
161152
return None
162-
163-
if self.pipeline_snapshot.agent_snapshot:
164-
return self.pipeline_snapshot.agent_snapshot.component_inputs["tool_invoker"]["serialized_data"]["state"]
165153
return self.pipeline_snapshot.pipeline_state.pipeline_outputs
166154

167155
@property
168-
def break_point(self) -> AgentBreakpoint | Breakpoint | ToolBreakpoint:
156+
def break_point(self) -> Breakpoint:
169157
"""
170-
Returns the Breakpoint or AgentBreakpoint that caused this exception, if available.
158+
Returns the Breakpoint that caused this exception.
171159
172160
If a specific break point was provided during initialization, it is returned.
173161
Otherwise, if the pipeline snapshot contains a break point, that is returned.

haystack/core/pipeline/breakpoint.py

Lines changed: 12 additions & 232 deletions
Original file line numberDiff line numberDiff line change
@@ -7,28 +7,15 @@
77
from collections.abc import Callable
88
from datetime import datetime
99
from pathlib import Path
10-
from typing import TYPE_CHECKING, Any
10+
from typing import Any
1111

1212
from networkx import MultiDiGraph
1313

1414
from haystack import logging
1515
from haystack.core.errors import PipelineInvalidPipelineSnapshotError
1616
from haystack.core.pipeline.utils import _deepcopy_with_exceptions
17-
from haystack.dataclasses import ChatMessage
18-
from haystack.dataclasses.breakpoints import (
19-
AgentBreakpoint,
20-
AgentSnapshot,
21-
Breakpoint,
22-
PipelineSnapshot,
23-
PipelineState,
24-
ToolBreakpoint,
25-
)
17+
from haystack.dataclasses.breakpoints import Breakpoint, PipelineSnapshot, PipelineState
2618
from haystack.utils.base_serialization import _serialize_value_with_schema
27-
from haystack.utils.misc import _get_output_dir
28-
29-
if TYPE_CHECKING:
30-
from haystack.components.agents.agent import _ExecutionContext
31-
from haystack.tools import ToolsType
3219

3320
logger = logging.getLogger(__name__)
3421

@@ -54,34 +41,17 @@ def _is_snapshot_save_enabled() -> bool:
5441
return value in ("true", "1")
5542

5643

57-
def _validate_break_point_against_pipeline(break_point: Breakpoint | AgentBreakpoint, graph: MultiDiGraph) -> None:
44+
def _validate_break_point_against_pipeline(break_point: Breakpoint, graph: MultiDiGraph) -> None:
5845
"""
5946
Validates the breakpoints passed to the pipeline.
6047
6148
Makes sure the breakpoint refers to a valid component registered in the pipeline.
6249
63-
:param break_point: a breakpoint to validate, can be Breakpoint or AgentBreakpoint
50+
:param break_point: a breakpoint to validate
6451
"""
65-
66-
# all Breakpoints must refer to a valid component in the pipeline
67-
if isinstance(break_point, Breakpoint) and break_point.component_name not in graph.nodes:
52+
if break_point.component_name not in graph.nodes:
6853
raise ValueError(f"break_point {break_point} is not a registered component in the pipeline")
6954

70-
if isinstance(break_point, AgentBreakpoint):
71-
breakpoint_agent_component = graph.nodes.get(break_point.agent_name)
72-
if not breakpoint_agent_component:
73-
raise ValueError(f"break_point {break_point} is not a registered Agent component in the pipeline")
74-
75-
if isinstance(break_point.break_point, ToolBreakpoint):
76-
instance = breakpoint_agent_component["instance"]
77-
for tool in instance.tools:
78-
if break_point.break_point.tool_name == tool.name:
79-
break
80-
else:
81-
raise ValueError(
82-
f"break_point {break_point.break_point} is not a registered tool in the Agent component"
83-
)
84-
8555

8656
def _validate_pipeline_snapshot_against_pipeline(pipeline_snapshot: PipelineSnapshot, graph: MultiDiGraph) -> None:
8757
"""
@@ -121,11 +91,7 @@ def _validate_pipeline_snapshot_against_pipeline(pipeline_snapshot: PipelineSnap
12191
f"are not part of the current pipeline."
12292
)
12393

124-
if isinstance(pipeline_snapshot.break_point, AgentBreakpoint):
125-
component_name = pipeline_snapshot.break_point.agent_name
126-
else:
127-
component_name = pipeline_snapshot.break_point.component_name
128-
94+
component_name = pipeline_snapshot.break_point.component_name
12995
visit_count = pipeline_snapshot.pipeline_state.component_visits[component_name]
13096

13197
logger.info(
@@ -216,30 +182,18 @@ def _save_pipeline_snapshot(
216182
return None
217183

218184
break_point = pipeline_snapshot.break_point
219-
snapshot_file_path = (
220-
break_point.break_point.snapshot_file_path
221-
if isinstance(break_point, AgentBreakpoint)
222-
else break_point.snapshot_file_path
223-
)
185+
snapshot_file_path = break_point.snapshot_file_path
224186

225187
if snapshot_file_path is None:
226188
return None
227189

228190
dt = pipeline_snapshot.timestamp or datetime.now()
229191
snapshot_dir = Path(snapshot_file_path)
230192

231-
# Generate filename
232-
# We check if the agent_name is provided to differentiate between agent and non-agent breakpoints
233-
if isinstance(break_point, AgentBreakpoint):
234-
agent_name = break_point.agent_name
235-
component_name = break_point.break_point.component_name
236-
else:
237-
component_name = break_point.component_name
238-
agent_name = None
239-
193+
component_name = break_point.component_name
240194
visit_nr = pipeline_snapshot.pipeline_state.component_visits.get(component_name, 0)
241195
timestamp = dt.strftime("%Y_%m_%d_%H_%M_%S")
242-
file_name = f"{agent_name + '_' if agent_name else ''}{component_name}_{visit_nr}_{timestamp}.json"
196+
file_name = f"{component_name}_{visit_nr}_{timestamp}.json"
243197
full_path = snapshot_dir / file_name
244198

245199
try:
@@ -262,7 +216,7 @@ def _create_pipeline_snapshot(
262216
*,
263217
inputs: dict[str, Any],
264218
component_inputs: dict[str, Any],
265-
break_point: AgentBreakpoint | Breakpoint,
219+
break_point: Breakpoint,
266220
component_visits: dict[str, int],
267221
original_input_data: dict[str, Any],
268222
ordered_component_names: list[str],
@@ -274,7 +228,7 @@ def _create_pipeline_snapshot(
274228
275229
:param inputs: The current pipeline snapshot inputs.
276230
:param component_inputs: The inputs to the component that triggered the breakpoint.
277-
:param break_point: The breakpoint that triggered the snapshot, can be AgentBreakpoint or Breakpoint.
231+
:param break_point: The breakpoint that triggered the snapshot.
278232
:param component_visits: The visit count of the component that triggered the breakpoint.
279233
:param original_input_data: The original input data.
280234
:param ordered_component_names: The ordered component names.
@@ -283,10 +237,7 @@ def _create_pipeline_snapshot(
283237
:returns:
284238
A PipelineSnapshot containing the state of the pipeline at the point of the breakpoint.
285239
"""
286-
if isinstance(break_point, AgentBreakpoint):
287-
component_name = break_point.agent_name
288-
else:
289-
component_name = break_point.component_name
240+
component_name = break_point.component_name
290241

291242
transformed_original_input_data = _transform_json_structure(original_input_data)
292243
transformed_inputs = _transform_json_structure({**inputs, component_name: component_inputs})
@@ -343,31 +294,6 @@ def _transform_json_structure(data: dict[str, Any] | list[Any] | Any) -> Any:
343294
return data
344295

345296

346-
def _create_agent_snapshot(
347-
*, component_visits: dict[str, int], agent_breakpoint: AgentBreakpoint, component_inputs: dict[str, Any]
348-
) -> AgentSnapshot:
349-
"""
350-
Create a snapshot of the agent's state.
351-
352-
:param component_visits: The visit counts for the agent's components.
353-
:param agent_breakpoint: AgentBreakpoint object containing breakpoints
354-
:return: An AgentSnapshot containing the agent's state and component visits.
355-
"""
356-
serialized_chat_generator = _serialize_agent_component_inputs(
357-
component_name="chat_generator", component_inputs=component_inputs["chat_generator"]
358-
)
359-
serialized_tool_invoker = _serialize_agent_component_inputs(
360-
component_name="tool_invoker", component_inputs=component_inputs["tool_invoker"]
361-
)
362-
363-
return AgentSnapshot(
364-
component_inputs={"chat_generator": serialized_chat_generator, "tool_invoker": serialized_tool_invoker},
365-
component_visits=component_visits,
366-
break_point=agent_breakpoint,
367-
timestamp=datetime.now(),
368-
)
369-
370-
371297
def _serialize_with_field_fallback(payload: Any, *, description: str) -> dict[str, Any]:
372298
"""
373299
Serialize a payload and, on failure, retry field-by-field to preserve resumable fields.
@@ -417,149 +343,3 @@ def _serialize_with_field_fallback(payload: Any, *, description: str) -> dict[st
417343
"serialization_schema": {"type": "object", "properties": serialized_properties},
418344
"serialized_data": serialized_data,
419345
}
420-
421-
422-
def _serialize_agent_component_inputs(component_name: str, component_inputs: dict[str, Any]) -> dict[str, Any]:
423-
"""
424-
Serialize agent component inputs while preserving resumable fields whenever possible.
425-
426-
Thin wrapper around :func:`_serialize_with_field_fallback` that supplies an agent-specific label
427-
for the warning messages.
428-
429-
:param component_name: Name of the agent sub-component (e.g. ``chat_generator`` or ``tool_invoker``).
430-
:param component_inputs: Runtime inputs for that sub-component.
431-
:returns: A serialized payload that is always a structurally valid ``{"serialization_schema",
432-
"serialized_data"}`` pair. When every field fails to serialize, an empty-but-valid object
433-
payload is returned so that ``_deserialize_value_with_schema`` can still load it (for example
434-
when resuming from a ``ToolBreakpoint`` where the sub-component's inputs are not strictly required).
435-
"""
436-
return _serialize_with_field_fallback(component_inputs, description=f"the agent's {component_name} inputs")
437-
438-
439-
def _validate_tool_breakpoint_is_valid(agent_breakpoint: AgentBreakpoint, tools: "ToolsType") -> None:
440-
"""
441-
Validates the AgentBreakpoint passed to the agent.
442-
443-
Validates that the tool name in ToolBreakpoints correspond to a tool available in the agent.
444-
445-
:param agent_breakpoint: AgentBreakpoint object containing breakpoints for the agent components.
446-
:param tools: A list of Tool and/or Toolset objects, or a Toolset that the agent can use.
447-
:raises ValueError: If any tool name in ToolBreakpoints is not available in the agent's tools.
448-
"""
449-
from haystack.tools.utils import flatten_tools_or_toolsets # avoid circular import
450-
451-
available_tool_names = {tool.name for tool in flatten_tools_or_toolsets(tools)}
452-
tool_breakpoint = agent_breakpoint.break_point
453-
# Assert added for mypy to pass, but this is already checked before this function is called
454-
assert isinstance(tool_breakpoint, ToolBreakpoint)
455-
if tool_breakpoint.tool_name and tool_breakpoint.tool_name not in available_tool_names:
456-
raise ValueError(f"Tool '{tool_breakpoint.tool_name}' is not available in the agent's tools")
457-
458-
459-
def _create_pipeline_snapshot_from_chat_generator(
460-
*, execution_context: "_ExecutionContext", agent_name: str | None = None, break_point: AgentBreakpoint | None = None
461-
) -> PipelineSnapshot:
462-
"""
463-
Create a pipeline snapshot when a chat generator breakpoint is raised or an exception during execution occurs.
464-
465-
:param execution_context: The current execution context of the agent.
466-
:param agent_name: The name of the agent component if present in a pipeline.
467-
:param break_point: An optional AgentBreakpoint object. If provided, it will be used instead of creating a new one.
468-
A scenario where a new breakpoint is created is when an exception occurs during chat generation and we want to
469-
capture the state at that point.
470-
:returns:
471-
A PipelineSnapshot containing the state of the pipeline and agent at the point of the breakpoint or exception.
472-
"""
473-
if break_point is None:
474-
agent_breakpoint = AgentBreakpoint(
475-
agent_name=agent_name or "agent",
476-
break_point=Breakpoint(
477-
component_name="chat_generator",
478-
visit_count=execution_context.component_visits["chat_generator"],
479-
snapshot_file_path=_get_output_dir("pipeline_snapshot"),
480-
),
481-
)
482-
else:
483-
agent_breakpoint = break_point
484-
485-
agent_snapshot = _create_agent_snapshot(
486-
component_visits=execution_context.component_visits,
487-
agent_breakpoint=agent_breakpoint,
488-
component_inputs={
489-
"chat_generator": {
490-
"messages": execution_context.state.data["messages"],
491-
**execution_context.chat_generator_inputs,
492-
},
493-
"tool_invoker": {"messages": [], "state": execution_context.state, **execution_context.tool_invoker_inputs},
494-
},
495-
)
496-
497-
return PipelineSnapshot._from_agent_snapshot(agent_snapshot=agent_snapshot)
498-
499-
500-
def _create_pipeline_snapshot_from_tool_invoker(
501-
*,
502-
execution_context: "_ExecutionContext",
503-
tool_name: str | None = None,
504-
agent_name: str | None = None,
505-
break_point: AgentBreakpoint | None = None,
506-
) -> PipelineSnapshot:
507-
"""
508-
Create a pipeline snapshot when a tool invoker breakpoint is raised or an exception during execution occurs.
509-
510-
:param execution_context: The current execution context of the agent.
511-
:param tool_name: The name of the tool that triggered the breakpoint, if available.
512-
:param agent_name: The name of the agent component if present in a pipeline.
513-
:param break_point: An optional AgentBreakpoint object. If provided, it will be used instead of creating a new one.
514-
A scenario where a new breakpoint is created is when an exception occurs during tool execution and we want to
515-
capture the state at that point.
516-
:returns:
517-
A PipelineSnapshot containing the state of the pipeline and agent at the point of the breakpoint or exception.
518-
"""
519-
if break_point is None:
520-
agent_breakpoint = AgentBreakpoint(
521-
agent_name=agent_name or "agent",
522-
break_point=ToolBreakpoint(
523-
component_name="tool_invoker",
524-
visit_count=execution_context.component_visits["tool_invoker"],
525-
tool_name=tool_name,
526-
snapshot_file_path=_get_output_dir("pipeline_snapshot"),
527-
),
528-
)
529-
else:
530-
agent_breakpoint = break_point
531-
532-
messages = execution_context.state.data["messages"]
533-
agent_snapshot = _create_agent_snapshot(
534-
component_visits=execution_context.component_visits,
535-
agent_breakpoint=agent_breakpoint,
536-
component_inputs={
537-
"chat_generator": {"messages": messages[:-1], **execution_context.chat_generator_inputs},
538-
"tool_invoker": {
539-
"messages": messages[-1:], # tool invoker consumes last msg from the chat_generator, contains tool call
540-
"state": execution_context.state,
541-
**execution_context.tool_invoker_inputs,
542-
},
543-
},
544-
)
545-
546-
# Create an empty pipeline snapshot
547-
return PipelineSnapshot._from_agent_snapshot(agent_snapshot=agent_snapshot)
548-
549-
550-
def _should_trigger_tool_invoker_breakpoint(break_point: ToolBreakpoint, llm_messages: list[ChatMessage]) -> bool:
551-
"""
552-
Determine if a tool invoker breakpoint should be triggered based on the provided ToolBreakpoint and LLM messages.
553-
554-
:param break_point: The ToolBreakpoint to check against.
555-
:param llm_messages: A list of ChatMessage objects representing the LLM messages.
556-
:returns:
557-
True if the breakpoint should be triggered, False otherwise.
558-
"""
559-
# Check if we should break for this specific tool or all tools
560-
if break_point.tool_name is None:
561-
# Break for any tool call
562-
return any(msg.tool_call for msg in llm_messages)
563-
564-
# Break only for the specific tool
565-
return any(tc.tool_name == break_point.tool_name for msg in llm_messages for tc in msg.tool_calls or [])

0 commit comments

Comments
 (0)