From 9a638aeb3b2ba676b29d975efb4dcc62fefafa9c Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Thu, 5 Jun 2025 19:10:16 +0200 Subject: [PATCH 01/30] improve breakpoints --- .../core/pipeline/pipeline.py | 72 ++++++++----------- 1 file changed, 29 insertions(+), 43 deletions(-) diff --git a/haystack_experimental/core/pipeline/pipeline.py b/haystack_experimental/core/pipeline/pipeline.py index 8b254366..268f9a1a 100644 --- a/haystack_experimental/core/pipeline/pipeline.py +++ b/haystack_experimental/core/pipeline/pipeline.py @@ -3,16 +3,17 @@ # SPDX-License-Identifier: Apache-2.0 import json -from copy import deepcopy +from copy import copy, deepcopy from datetime import datetime from pathlib import Path from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Tuple, Union, cast from haystack import Answer, Document, ExtractedAnswer, logging, tracing -from haystack.components.joiners import BranchJoiner, DocumentJoiner +from haystack.components.joiners import DocumentJoiner from haystack.core.component import Component from haystack.dataclasses import ChatMessage, GeneratedAnswer, SparseEmbedding from haystack.telemetry import pipeline_running +from haystack.utils import _serialize_value_with_schema from haystack_experimental.core.errors import ( PipelineBreakpointException, @@ -289,10 +290,21 @@ def run( # noqa: PLR0915, PLR0912 component_inputs[key] = deserialize_component_input(value) if validated_breakpoints and not self.resume_state: - state_inputs_serialised = remove_unserializable_data(deepcopy(inputs)) + state_inputs_serialised = deepcopy(inputs) # inject the component_inputs into the state_inputs so we can this component init params in # the JSON state - state_inputs_serialised[component_name] = remove_unserializable_data(deepcopy(component_inputs)) + state_inputs_serialised[component_name] = deepcopy(component_inputs) + + # TODO: Adding the old approach until we finalize if we + # want to use serialization with schema or without + + # We use copy instead of deepcopy to avoid issues with unpickleable objects like RLock + params = copy(component["instance"].__dict__) + + # this is needed as the template param is stored as _template_string in the component's __dict__ + # if "_template_string" in params: + # params["template"] = params["_template_string"] + state_inputs_serialised[component_name]["init_parameters"] = params Pipeline._check_breakpoints( breakpoints=validated_breakpoints, @@ -301,22 +313,13 @@ def run( # noqa: PLR0915, PLR0912 inputs=state_inputs_serialised, debug_path=self.debug_path, original_input_data=data, - ordered_component_names=self.ordered_component_names + ordered_component_names=self.ordered_component_names, ) - # the _consume_component_inputs() when applied to the DocumentJoiner inputs wraps 'documents' in an - # extra list, so there's a 3 level deep list, we need to flatten it to 2 levels only - instance: Component = component["instance"] - if self.resume_state and isinstance(instance, DocumentJoiner): # noqa: SIM102 - if isinstance(component_inputs["documents"], list): # noqa: SIM102 - if isinstance(component_inputs["documents"][0], list): # noqa: SIM102 - if isinstance(component_inputs["documents"][0][0], list): # noqa: SIM102 - component_inputs["documents"] = component_inputs["documents"][0] - component_outputs = self._run_component( component_name=component_name, component=component, - inputs=component_inputs, # the inputs to the current component + inputs=component_inputs, # the inputs to the current component component_visits=component_visits, parent_span=span, ) @@ -520,6 +523,15 @@ def save_state( The saved state dictionary """ dt = datetime.now() + + # we store a component init_parameters together with the breakpoint in the saved state + # this is helpful for debugging or manually updating the state + for value in inputs.values(): + if "init_parameters" in value.keys(): + init_params = value.pop("init_parameters") + for k, v in value.items(): + if k in init_params.keys() and v is None: + value[k] = serialize_component_input(init_params[k]) state = { "input_data": serialize_component_input(original_input_data), # original input data "timestamp": dt.isoformat(), @@ -563,7 +575,7 @@ def _check_breakpoints( inputs: Dict[str, Any], debug_path: Optional[Union[str, Path]] = None, original_input_data: Optional[Dict[str, Any]] = None, - ordered_component_names: Optional[List[str]] = None + ordered_component_names: Optional[List[str]] = None, ): """ Check if the `component_name` is in the breakpoints and if it should break. @@ -588,12 +600,11 @@ def _check_breakpoints( component_visits=component_visits, debug_path=debug_path, original_input_data=original_input_data, - ordered_component_names=ordered_component_names + ordered_component_names=ordered_component_names, ) raise PipelineBreakpointException(msg, component=component_name, state=state) - def deserialize_component_input(value): # noqa: PLR0911 """ Tries to deserialize any type of input that can be passed to as input to a pipeline component. @@ -668,33 +679,10 @@ def transform_json_structure(data: Union[Dict[str, Any], List[Any], Any]) -> Any return data -def remove_unserializable_data(value: Any) -> Any: - """ - Removes certain unserializable data which is not needed for the pipeline state. - """ - - if isinstance(value, ChatMessage): # noqa: SIM102 - if "usage" in value.meta: # noqa: SIM102 - value.meta["usage"].pop("completion_tokens_details", None) - value.meta["usage"].pop("prompt_tokens_details", None) - - if isinstance(value, GeneratedAnswer): # noqa: SIM102 - if value.meta and "usage" in value.meta: # noqa: SIM102 - value.meta.pop("usage", None) - - # all_messages contains a list of unserialized ChatMessages - # TODO: we should find a better way to handle this - if value.meta and "all_messages" in value.meta: - value.meta.pop("all_messages", None) - - return value - - def serialize_component_input(value: Any) -> Any: """ Serializes, so it can be saved to a file, any type of input to a pipeline component. """ - value = remove_unserializable_data(value) value = transform_json_structure(value) if hasattr(value, "to_dict") and callable(getattr(value, "to_dict")): serialized_value = value.to_dict() @@ -717,5 +705,3 @@ def serialize_component_input(value: Any) -> Any: return [serialize_component_input(item) for item in value] return value - - From 189a0e5bc14fa5f3a975b1de0c26c8bb7c50e3b9 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Thu, 5 Jun 2025 19:17:27 +0200 Subject: [PATCH 02/30] updates --- haystack_experimental/core/pipeline/pipeline.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/haystack_experimental/core/pipeline/pipeline.py b/haystack_experimental/core/pipeline/pipeline.py index 268f9a1a..cfa20134 100644 --- a/haystack_experimental/core/pipeline/pipeline.py +++ b/haystack_experimental/core/pipeline/pipeline.py @@ -9,11 +9,9 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Tuple, Union, cast from haystack import Answer, Document, ExtractedAnswer, logging, tracing -from haystack.components.joiners import DocumentJoiner from haystack.core.component import Component from haystack.dataclasses import ChatMessage, GeneratedAnswer, SparseEmbedding from haystack.telemetry import pipeline_running -from haystack.utils import _serialize_value_with_schema from haystack_experimental.core.errors import ( PipelineBreakpointException, @@ -300,10 +298,6 @@ def run( # noqa: PLR0915, PLR0912 # We use copy instead of deepcopy to avoid issues with unpickleable objects like RLock params = copy(component["instance"].__dict__) - - # this is needed as the template param is stored as _template_string in the component's __dict__ - # if "_template_string" in params: - # params["template"] = params["_template_string"] state_inputs_serialised[component_name]["init_parameters"] = params Pipeline._check_breakpoints( From e79a5beeafee62ec2d5a2db52b23dc9174e8748b Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Fri, 6 Jun 2025 15:59:02 +0200 Subject: [PATCH 03/30] Fixes --- .../core/pipeline/pipeline.py | 126 +++++++++++++----- pyproject.toml | 4 +- test/core/pipeline/test_pipeline.py | 32 ++--- 3 files changed, 113 insertions(+), 49 deletions(-) diff --git a/haystack_experimental/core/pipeline/pipeline.py b/haystack_experimental/core/pipeline/pipeline.py index cfa20134..6a4dc2c0 100644 --- a/haystack_experimental/core/pipeline/pipeline.py +++ b/haystack_experimental/core/pipeline/pipeline.py @@ -10,6 +10,7 @@ from haystack import Answer, Document, ExtractedAnswer, logging, tracing from haystack.core.component import Component +from haystack.core.pipeline.component_checks import is_socket_lazy_variadic from haystack.dataclasses import ChatMessage, GeneratedAnswer, SparseEmbedding from haystack.telemetry import pipeline_running @@ -285,7 +286,7 @@ def run( # noqa: PLR0915, PLR0912 # this check will prevent other component_inputs generated at runtime from being deserialized if self.resume_state and component_name in self.resume_state["pipeline_state"]["inputs"].keys(): for key, value in component_inputs.items(): - component_inputs[key] = deserialize_component_input(value) + component_inputs[key] = _deserialize_component_input(value) if validated_breakpoints and not self.resume_state: state_inputs_serialised = deepcopy(inputs) @@ -293,13 +294,18 @@ def run( # noqa: PLR0915, PLR0912 # the JSON state state_inputs_serialised[component_name] = deepcopy(component_inputs) - # TODO: Adding the old approach until we finalize if we - # want to use serialization with schema or without - + # we use dict instead of to_dict() because it strips away class types # We use copy instead of deepcopy to avoid issues with unpickleable objects like RLock params = copy(component["instance"].__dict__) state_inputs_serialised[component_name]["init_parameters"] = params + if "_template_string" in params: + params["template"] = params["_template_string"] + params.pop("_template_string") + + print("PARAMS") + print(params) + Pipeline._check_breakpoints( breakpoints=validated_breakpoints, component_name=component_name, @@ -309,6 +315,13 @@ def run( # noqa: PLR0915, PLR0912 original_input_data=data, ordered_component_names=self.ordered_component_names, ) + # the _consume_component_inputs() when applied to the DocumentJoiner inputs wraps 'documents' in an + # extra list, so there's a 3 level deep list, we need to flatten it to 2 levels only + + if self.resume_state and self.resume_state["breakpoint"]["component"] == component_name: + for socket_name, socket in component["input_sockets"].items(): + if is_socket_lazy_variadic(socket): + component_inputs[socket_name] = component_inputs[socket_name][0] component_outputs = self._run_component( component_name=component_name, @@ -518,29 +531,31 @@ def save_state( """ dt = datetime.now() - # we store a component init_parameters together with the breakpoint in the saved state - # this is helpful for debugging or manually updating the state + # we store a input params passed during init() in the saved state + # this is helpful for retaining the state of the component and manual debugging for value in inputs.values(): - if "init_parameters" in value.keys(): - init_params = value.pop("init_parameters") - for k, v in value.items(): - if k in init_params.keys() and v is None: - value[k] = serialize_component_input(init_params[k]) + if "init_parameters" not in value: + continue + init_params = value.pop("init_parameters") + for k, v in value.items(): + if k in init_params and not v: + value[k] = _serialize_component_input(init_params[k]) + state = { - "input_data": serialize_component_input(original_input_data), # original input data + "input_data": _serialize_component_input(original_input_data), # original input data "timestamp": dt.isoformat(), "breakpoint": {"component": component_name, "visits": component_visits[component_name]}, "pipeline_state": { - "inputs": serialize_component_input(inputs), # current pipeline state inputs + "inputs": _serialize_component_input(inputs), # current pipeline state inputs "component_visits": component_visits, "ordered_component_names": ordered_component_names, }, } + if not debug_path: return state - if isinstance(debug_path, str): - debug_path = Path(debug_path) + debug_path = Path(debug_path) if isinstance(debug_path, str) else debug_path if not isinstance(debug_path, Path): raise ValueError("Debug path must be a string or a Path object.") debug_path.mkdir(exist_ok=True) @@ -570,7 +585,7 @@ def _check_breakpoints( debug_path: Optional[Union[str, Path]] = None, original_input_data: Optional[Dict[str, Any]] = None, ordered_component_names: Optional[List[str]] = None, - ): + ) -> None: """ Check if the `component_name` is in the breakpoints and if it should break. @@ -578,6 +593,10 @@ def _check_breakpoints( :param component_name: Name of the component to check. :param component_visits: The number of times the component has been visited. :param inputs: The inputs to the pipeline. + :param debug_path: The file path where debug state is saved. + :param original_input_data: The original input data to the pipeline. + :param ordered_component_names: The ordered component names in the pipeline. + :raises PipelineBreakpointException: When a breakpoint is triggered, with component state information. """ matching_breakpoints = [bp for bp in breakpoints if bp[0] == component_name] @@ -598,8 +617,48 @@ def _check_breakpoints( ) raise PipelineBreakpointException(msg, component=component_name, state=state) + @staticmethod + def _check_breakpoints_new( + breakpoints: Set[Tuple[str, int]], + component_name: str, + component_visits: Dict[str, int], + inputs: Dict[str, Any], + debug_path: Optional[Union[str, Path]] = None, + original_input_data: Optional[Dict[str, Any]] = None, + ordered_component_names: Optional[List[str]] = None, + ) -> None: + """ + Check if the `component_name` is in the breakpoints and if it should break. -def deserialize_component_input(value): # noqa: PLR0911 + :param breakpoints: Set of tuples of component names and visit counts at which the pipeline should stop. + :param component_name: Name of the component to check. + :param component_visits: The number of times the component has been visited. + :param inputs: The inputs to the pipeline. + :param debug_path: The file path where debug state is saved. + :param original_input_data: The original input data to the pipeline. + :param ordered_component_names: The ordered component names in the pipeline. + + :raises PipelineBreakpointException: When a breakpoint is triggered, with component state information. + """ + for component, visit_count in breakpoints: + if component != component_name: + continue + + if visit_count == component_visits[component_name]: + msg = f"Breaking at component {component_name} visit count {component_visits[component_name]}" + logger.info(msg) + state = Pipeline.save_state( + inputs=inputs, + component_name=str(component_name), + component_visits=component_visits, + debug_path=debug_path, + original_input_data=original_input_data, + ordered_component_names=ordered_component_names, + ) + raise PipelineBreakpointException(msg, component=component_name, state=state) + + +def _deserialize_component_input(value: Any) -> Any: # noqa: PLR0911 """ Tries to deserialize any type of input that can be passed to as input to a pipeline component. @@ -617,10 +676,10 @@ def deserialize_component_input(value): # noqa: PLR0911 if isinstance(value, list): # list of lists are called recursively if all(isinstance(i, list) for i in value): - return [deserialize_component_input(i) for i in value] + return [_deserialize_component_input(i) for i in value] # list of dicts are called recursively if all(isinstance(i, dict) for i in value): - return [deserialize_component_input(i) for i in value] + return [_deserialize_component_input(i) for i in value] # Define the mapping of types to their deserialization functions _type_deserializers = { @@ -640,44 +699,49 @@ def deserialize_component_input(value): # noqa: PLR0911 return _type_deserializers[type_name](value) # If not a known type, recursively deserialize each item in the dictionary - return {k: deserialize_component_input(v) for k, v in value.items()} + return {k: _deserialize_component_input(v) for k, v in value.items()} return value -def transform_json_structure(data: Union[Dict[str, Any], List[Any], Any]) -> Any: +def _transform_json_structure(data: Union[Dict[str, Any], List[Any], Any]) -> Any: """ Transforms a JSON structure by removing the 'sender' key and moving the 'value' to the top level. For example: "key": [{"sender": null, "value": "some value"}] -> "key": "some value" + + :param data: The JSON structure to transform. + :returns: The transformed structure. """ if isinstance(data, dict): # If this dict has both 'sender' and 'value', return just the value if "value" in data and "sender" in data: return data["value"] # Otherwise, recursively process each key-value pair - return {k: transform_json_structure(v) for k, v in data.items()} + return {k: _transform_json_structure(v) for k, v in data.items()} - elif isinstance(data, list): + if isinstance(data, list): # First, transform each item in the list. - transformed = [transform_json_structure(item) for item in data] + transformed = [_transform_json_structure(item) for item in data] # If the original list has exactly one element and that element was a dict # with 'sender' and 'value', then unwrap the list. if len(data) == 1 and isinstance(data[0], dict) and "value" in data[0] and "sender" in data[0]: return transformed[0] return transformed - else: - # For other data types, just return the value as is. - return data + # For other data types, just return the value as is. + return data -def serialize_component_input(value: Any) -> Any: +def _serialize_component_input(value: Any) -> Any: """ Serializes, so it can be saved to a file, any type of input to a pipeline component. + + :param value: The value to serialize. + :returns: The serialized value that can be saved to a file. """ - value = transform_json_structure(value) + value = _transform_json_structure(value) if hasattr(value, "to_dict") and callable(getattr(value, "to_dict")): serialized_value = value.to_dict() serialized_value["_type"] = value.__class__.__name__ @@ -692,10 +756,10 @@ def serialize_component_input(value: Any) -> Any: # recursively serialize all inputs in a dict elif isinstance(value, dict): - return {k: serialize_component_input(v) for k, v in value.items()} + return {k: _serialize_component_input(v) for k, v in value.items()} # recursively serialize all inputs in lists or tuples elif isinstance(value, list): - return [serialize_component_input(item) for item in value] + return [_serialize_component_input(item) for item in value] return value diff --git a/pyproject.toml b/pyproject.toml index f2fbb6f9..d760786d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,12 +22,12 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.13", "Topic :: Scientific/Engineering :: Artificial Intelligence", ] dependencies = [ - "haystack-ai", + "haystack-ai @ git+https://github.com/deepset-ai/haystack.git@main", "filetype", # for mime type detection in ImageContent ] diff --git a/test/core/pipeline/test_pipeline.py b/test/core/pipeline/test_pipeline.py index 9a30032e..91cfb6b7 100644 --- a/test/core/pipeline/test_pipeline.py +++ b/test/core/pipeline/test_pipeline.py @@ -11,9 +11,9 @@ from haystack_experimental.core.errors import PipelineRuntimeError from haystack_experimental.core.pipeline.pipeline import ( Pipeline, - transform_json_structure, - serialize_component_input, - deserialize_component_input + _transform_json_structure, + _serialize_component_input, + _deserialize_component_input ) @@ -156,14 +156,14 @@ def test_transform_json_structure_unwraps_sender_value(): "key3": "direct value" } - result = transform_json_structure(data) + result = _transform_json_structure(data) assert result == { "key1": "some value", "key2": 42, "key3": "direct value" } - + def test_transform_json_structure_handles_nested_structures(): data = { "key1": [{"sender": None, "value": "value1"}], @@ -177,7 +177,7 @@ def test_transform_json_structure_handles_nested_structures(): ] } - result = transform_json_structure(data) + result = _transform_json_structure(data) assert result == { "key1": "value1", @@ -190,7 +190,7 @@ def test_transform_json_structure_handles_nested_structures(): "value5" ] } - + def test_serialize_component_input_handles_objects_with_to_dict(): class TestObject: def __init__(self, value): @@ -200,7 +200,7 @@ def to_dict(self): return {"value": self.value} obj = TestObject("test") - result = serialize_component_input(obj) + result = _serialize_component_input(obj) assert result == { "_type": "TestObject", "value": "test" @@ -212,7 +212,7 @@ def __init__(self, value): self.value = value obj = TestObject("test") - result = serialize_component_input(obj) + result = _serialize_component_input(obj) assert result == { "_type": "TestObject", "attributes": {"value": "test"} @@ -232,7 +232,7 @@ def to_dict(self): "key2": [obj, "string"], "key3": {"nested": obj} } - result = serialize_component_input(data) + result = _serialize_component_input(data) assert result["key1"]["_type"] == "TestObject" assert result["key2"][0]["_type"] == "TestObject" @@ -247,7 +247,7 @@ def test_deserialize_component_input_handles_primitive_types(): "bool": True, "none": None } - result = deserialize_component_input(data) + result = _deserialize_component_input(data) assert result == data def test_deserialize_component_input_handles_lists(): @@ -255,7 +255,7 @@ def test_deserialize_component_input_handles_lists(): "primitive_list": [1, 2, 3], "mixed_list": [1, "string", True] } - result = deserialize_component_input(data) + result = _deserialize_component_input(data) assert result == data def test_deserialize_component_input_handles_dicts(): @@ -263,7 +263,7 @@ def test_deserialize_component_input_handles_dicts(): "key1": "value1", "key2": {"nested": "value2"} } - result = deserialize_component_input(data) + result = _deserialize_component_input(data) assert result == data def test_deserialize_component_input_handles_nested_lists(): @@ -273,7 +273,7 @@ def test_deserialize_component_input_handles_nested_lists(): "mixed_nested": [[1, "string"], [True, 3.14]] } - result = deserialize_component_input(data) + result = _deserialize_component_input(data) assert result == data @@ -288,7 +288,7 @@ def test_deserialize_component_input_handles_nested_dicts(): } } - result = deserialize_component_input(data) + result = _deserialize_component_input(data) assert result == data @@ -300,7 +300,7 @@ def test_deserialize_component_input_handles_empty_structures(): "nested_empty": {"empty": []} } - result = deserialize_component_input(data) + result = _deserialize_component_input(data) assert result == data From 87e3cf77ea34ef561ae9865738088d8c26459582 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Fri, 6 Jun 2025 16:13:28 +0200 Subject: [PATCH 04/30] Remove extra methods --- .../core/pipeline/pipeline.py | 44 ------------------- 1 file changed, 44 deletions(-) diff --git a/haystack_experimental/core/pipeline/pipeline.py b/haystack_experimental/core/pipeline/pipeline.py index 6a4dc2c0..14045472 100644 --- a/haystack_experimental/core/pipeline/pipeline.py +++ b/haystack_experimental/core/pipeline/pipeline.py @@ -303,9 +303,6 @@ def run( # noqa: PLR0915, PLR0912 params["template"] = params["_template_string"] params.pop("_template_string") - print("PARAMS") - print(params) - Pipeline._check_breakpoints( breakpoints=validated_breakpoints, component_name=component_name, @@ -597,47 +594,6 @@ def _check_breakpoints( :param original_input_data: The original input data to the pipeline. :param ordered_component_names: The ordered component names in the pipeline. - :raises PipelineBreakpointException: When a breakpoint is triggered, with component state information. - """ - matching_breakpoints = [bp for bp in breakpoints if bp[0] == component_name] - - for bp in matching_breakpoints: - visit_count = bp[1] - # break only if the visit count is the same - if visit_count == component_visits[component_name]: - msg = f"Breaking at component {component_name} visit count {component_visits[component_name]}" - logger.info(msg) - state = Pipeline.save_state( - inputs=inputs, - component_name=str(component_name), - component_visits=component_visits, - debug_path=debug_path, - original_input_data=original_input_data, - ordered_component_names=ordered_component_names, - ) - raise PipelineBreakpointException(msg, component=component_name, state=state) - - @staticmethod - def _check_breakpoints_new( - breakpoints: Set[Tuple[str, int]], - component_name: str, - component_visits: Dict[str, int], - inputs: Dict[str, Any], - debug_path: Optional[Union[str, Path]] = None, - original_input_data: Optional[Dict[str, Any]] = None, - ordered_component_names: Optional[List[str]] = None, - ) -> None: - """ - Check if the `component_name` is in the breakpoints and if it should break. - - :param breakpoints: Set of tuples of component names and visit counts at which the pipeline should stop. - :param component_name: Name of the component to check. - :param component_visits: The number of times the component has been visited. - :param inputs: The inputs to the pipeline. - :param debug_path: The file path where debug state is saved. - :param original_input_data: The original input data to the pipeline. - :param ordered_component_names: The ordered component names in the pipeline. - :raises PipelineBreakpointException: When a breakpoint is triggered, with component state information. """ for component, visit_count in breakpoints: From cc098a7295bb1ca2712490049826fe4ac0d7904c Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Tue, 10 Jun 2025 14:53:35 +0200 Subject: [PATCH 05/30] Update init params --- haystack_experimental/core/pipeline/pipeline.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/haystack_experimental/core/pipeline/pipeline.py b/haystack_experimental/core/pipeline/pipeline.py index 14045472..0902d810 100644 --- a/haystack_experimental/core/pipeline/pipeline.py +++ b/haystack_experimental/core/pipeline/pipeline.py @@ -13,6 +13,7 @@ from haystack.core.pipeline.component_checks import is_socket_lazy_variadic from haystack.dataclasses import ChatMessage, GeneratedAnswer, SparseEmbedding from haystack.telemetry import pipeline_running +from haystack.utils.base_serialization import _deserialize_value_with_schema, _serialize_value_with_schema from haystack_experimental.core.errors import ( PipelineBreakpointException, @@ -273,7 +274,6 @@ def run( # noqa: PLR0915, PLR0912 component_name, component_visits[component_name] ) - # this breaks the pipeline breakpoints component_inputs = self._consume_component_inputs( component_name=component_name, component=component, inputs=inputs ) @@ -294,14 +294,15 @@ def run( # noqa: PLR0915, PLR0912 # the JSON state state_inputs_serialised[component_name] = deepcopy(component_inputs) - # we use dict instead of to_dict() because it strips away class types - # We use copy instead of deepcopy to avoid issues with unpickleable objects like RLock - params = copy(component["instance"].__dict__) - state_inputs_serialised[component_name]["init_parameters"] = params + # the init params are stored for the component with breakpoint + # this is helpful for retaining the state of the component and manual debugging + init_params = {} + # we use dict instead of to_dict() because it strips away class types of init params + for key, value in component["instance"].__dict__.items(): + if not key.startswith("__"): + init_params[key] = value - if "_template_string" in params: - params["template"] = params["_template_string"] - params.pop("_template_string") + state_inputs_serialised[component_name]["init_parameters"] = init_params Pipeline._check_breakpoints( breakpoints=validated_breakpoints, From b01ffaf3e97c0f61f20b0db585982ab4864d3f0e Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Tue, 10 Jun 2025 15:24:08 +0200 Subject: [PATCH 06/30] type fix --- haystack_experimental/core/pipeline/pipeline.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/haystack_experimental/core/pipeline/pipeline.py b/haystack_experimental/core/pipeline/pipeline.py index ca82b961..af58d77b 100644 --- a/haystack_experimental/core/pipeline/pipeline.py +++ b/haystack_experimental/core/pipeline/pipeline.py @@ -292,19 +292,16 @@ def run( # noqa: PLR0915, PLR0912 if validated_breakpoints and not self.resume_state: state_inputs_serialised = deepcopy(inputs) - # inject the component_inputs into the state_inputs so we can this component init params in - # the JSON state + # we store the init params for the component with breakpoint in debug state + # this is helpful for retaining the state of the component and manual debugging state_inputs_serialised[component_name] = deepcopy(component_inputs) - # the init params are stored for the component with breakpoint - # this is helpful for retaining the state of the component and manual debugging - init_params = {} # we use dict instead of to_dict() because it strips away class types of init params - for key, value in component["instance"].__dict__.items(): - if not key.startswith("__"): - init_params[key] = value - - state_inputs_serialised[component_name]["init_parameters"] = init_params + state_inputs_serialised[component_name]["init_parameters"] = { + key: value + for key, value in component["instance"].__dict__.items() + if not key.startswith("__") + } # type: ignore[assignment] Pipeline._check_breakpoints( breakpoints=validated_breakpoints, From 684b584c7d9b372245a0de98343868880dec4955 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Tue, 10 Jun 2025 15:39:28 +0200 Subject: [PATCH 07/30] Add GeneratedAnswer to experimental --- .../core/pipeline/pipeline.py | 7 +- haystack_experimental/dataclasses/__init__.py | 2 + haystack_experimental/dataclasses/answer.py | 134 ++++++++++++++++++ pyproject.toml | 4 +- 4 files changed, 142 insertions(+), 5 deletions(-) create mode 100644 haystack_experimental/dataclasses/answer.py diff --git a/haystack_experimental/core/pipeline/pipeline.py b/haystack_experimental/core/pipeline/pipeline.py index af58d77b..5b0bdf0f 100644 --- a/haystack_experimental/core/pipeline/pipeline.py +++ b/haystack_experimental/core/pipeline/pipeline.py @@ -14,7 +14,7 @@ from haystack import Answer, Document, ExtractedAnswer, logging, tracing from haystack.core.component import Component from haystack.core.pipeline.component_checks import is_socket_lazy_variadic -from haystack.dataclasses import ChatMessage, GeneratedAnswer, SparseEmbedding +from haystack.dataclasses import ChatMessage, SparseEmbedding from haystack.telemetry import pipeline_running from haystack_experimental.core.errors import ( @@ -29,6 +29,7 @@ ComponentPriority, PipelineBase, ) +from haystack_experimental.dataclasses import GeneratedAnswer logger = logging.getLogger(__name__) @@ -312,9 +313,9 @@ def run( # noqa: PLR0915, PLR0912 original_input_data=data, ordered_component_names=self.ordered_component_names, ) - # the _consume_component_inputs() when applied to the DocumentJoiner inputs wraps 'documents' in an - # extra list, so there's a 3 level deep list, we need to flatten it to 2 levels only + # the _consume_component_inputs() creates a 3 level deep list for lazy variadic component + # we need to flatten it to 2 levels only if self.resume_state and self.resume_state["breakpoint"]["component"] == component_name: for socket_name, socket in component["input_sockets"].items(): if is_socket_lazy_variadic(socket): diff --git a/haystack_experimental/dataclasses/__init__.py b/haystack_experimental/dataclasses/__init__.py index 429ab99b..5af5fc85 100644 --- a/haystack_experimental/dataclasses/__init__.py +++ b/haystack_experimental/dataclasses/__init__.py @@ -10,9 +10,11 @@ _import_structure = { "chat_message": ["ChatMessage"], "image_content": ["ImageContent"], + "answer": ["Answer", "ExtractedAnswer", "GeneratedAnswer"], } if TYPE_CHECKING: + from .answer import Answer, ExtractedAnswer, GeneratedAnswer from .chat_message import ChatMessage from .image_content import ImageContent else: diff --git a/haystack_experimental/dataclasses/answer.py b/haystack_experimental/dataclasses/answer.py new file mode 100644 index 00000000..57d5cd5c --- /dev/null +++ b/haystack_experimental/dataclasses/answer.py @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import asdict, dataclass, field +from typing import Any, Dict, List, Optional, Protocol, runtime_checkable + +from haystack.core.serialization import default_from_dict, default_to_dict +from haystack.dataclasses import ChatMessage, Document + + +@runtime_checkable +@dataclass +class Answer(Protocol): + data: Any + query: str + meta: Dict[str, Any] + + def to_dict(self) -> Dict[str, Any]: # noqa: D102 + ... + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Answer": # noqa: D102 + ... + + +@dataclass +class ExtractedAnswer: + query: str + score: float + data: Optional[str] = None + document: Optional[Document] = None + context: Optional[str] = None + document_offset: Optional["Span"] = None + context_offset: Optional["Span"] = None + meta: Dict[str, Any] = field(default_factory=dict) + + @dataclass + class Span: + start: int + end: int + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize the object to a dictionary. + + :returns: + Serialized dictionary representation of the object. + """ + document = self.document.to_dict(flatten=False) if self.document is not None else None + document_offset = asdict(self.document_offset) if self.document_offset is not None else None + context_offset = asdict(self.context_offset) if self.context_offset is not None else None + return default_to_dict( + self, + data=self.data, + query=self.query, + document=document, + context=self.context, + score=self.score, + document_offset=document_offset, + context_offset=context_offset, + meta=self.meta, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ExtractedAnswer": + """ + Deserialize the object from a dictionary. + + :param data: + Dictionary representation of the object. + :returns: + Deserialized object. + """ + init_params = data.get("init_parameters", {}) + if (doc := init_params.get("document")) is not None: + data["init_parameters"]["document"] = Document.from_dict(doc) + + if (offset := init_params.get("document_offset")) is not None: + data["init_parameters"]["document_offset"] = ExtractedAnswer.Span(**offset) + + if (offset := init_params.get("context_offset")) is not None: + data["init_parameters"]["context_offset"] = ExtractedAnswer.Span(**offset) + return default_from_dict(cls, data) + + +@dataclass +class GeneratedAnswer: + data: str + query: str + documents: List[Document] + meta: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize the object to a dictionary. + + :returns: + Serialized dictionary representation of the object. + """ + documents = [doc.to_dict(flatten=False) for doc in self.documents] + + # Serialize ChatMessage objects to dicts + meta = self.meta + all_messages = meta.get("all_messages") + + # all_messages is either a list of ChatMessage objects or a list of strings + if all_messages and isinstance(all_messages[0], ChatMessage): + meta = {**meta, "all_messages": [msg.to_dict() for msg in all_messages]} + + return default_to_dict(self, data=self.data, query=self.query, documents=documents, meta=meta) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "GeneratedAnswer": + """ + Deserialize the object from a dictionary. + + :param data: + Dictionary representation of the object. + + :returns: + Deserialized object. + """ + init_params = data.get("init_parameters", {}) + + if (documents := init_params.get("documents")) is not None: + init_params["documents"] = [Document.from_dict(d) for d in documents] + + meta = init_params.get("meta", {}) + if (all_messages := meta.get("all_messages")) is not None and isinstance(all_messages[0], dict): + meta["all_messages"] = [ChatMessage.from_dict(m) for m in all_messages] + init_params["meta"] = meta + + return default_from_dict(cls, data) diff --git a/pyproject.toml b/pyproject.toml index 116b5fbb..2fc55313 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ classifiers = [ ] dependencies = [ - "haystack-ai @ git+https://github.com/deepset-ai/haystack.git@main", + "haystack-ai", "filetype", # for mime type detection in ImageContent ] @@ -52,7 +52,7 @@ fmt-check = "ruff check {args} && ruff format --check {args}" [tool.hatch.envs.test] extra-dependencies = [ "colorama", # Pipeline checkpoints experiment - "transformers[torch,sentencepiece]>=4.51.1,<4.52", # Pipeline checkpoints experiment + "transformers[torch, sentencepiece]>=4.51.1,<4.52", # Pipeline checkpoints experiment "arrow>=1.3.0", # Multimodal experiment - ChatPromptBuilder "pypdfium2", # Multimodal experiment - PDFToImageContent "pillow", # Multimodal experiment - ImageFileToImageContent, PDFToImageContent From 577837dc6926f79a64cd8af7dea03569b001432f Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Tue, 10 Jun 2025 16:22:49 +0200 Subject: [PATCH 08/30] Add answer builder --- .../components/builders/answer_builder.py | 200 ++++++++++++++++++ ...test_pipeline_breakpoints_answer_joiner.py | 2 +- .../test_pipeline_breakpoints_rag_hybrid.py | 2 +- 3 files changed, 202 insertions(+), 2 deletions(-) create mode 100644 haystack_experimental/components/builders/answer_builder.py diff --git a/haystack_experimental/components/builders/answer_builder.py b/haystack_experimental/components/builders/answer_builder.py new file mode 100644 index 00000000..c96f117c --- /dev/null +++ b/haystack_experimental/components/builders/answer_builder.py @@ -0,0 +1,200 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import re +from typing import Any, Dict, List, Optional, Union + +from haystack import Document, component, logging +from haystack.dataclasses.chat_message import ChatMessage + +from haystack_experimental.dataclasses import GeneratedAnswer + +logger = logging.getLogger(__name__) + + +@component +class AnswerBuilder: + """ + Converts a query and Generator replies into a `GeneratedAnswer` object. + + AnswerBuilder parses Generator replies using custom regular expressions. + Check out the usage example below to see how it works. + Optionally, it can also take documents and metadata from the Generator to add to the `GeneratedAnswer` object. + AnswerBuilder works with both non-chat and chat Generators. + + ### Usage example + + ```python + from haystack.components.builders import AnswerBuilder + + builder = AnswerBuilder(pattern="Answer: (.*)") + builder.run(query="What's the answer?", replies=["This is an argument. Answer: This is the answer."]) + ``` + """ + + def __init__( + self, pattern: Optional[str] = None, reference_pattern: Optional[str] = None, last_message_only: bool = False + ): + """ + Creates an instance of the AnswerBuilder component. + + :param pattern: + The regular expression pattern to extract the answer text from the Generator. + If not specified, the entire response is used as the answer. + The regular expression can have one capture group at most. + If present, the capture group text + is used as the answer. If no capture group is present, the whole match is used as the answer. + Examples: + `[^\\n]+$` finds "this is an answer" in a string "this is an argument.\\nthis is an answer". + `Answer: (.*)` finds "this is an answer" in a string "this is an argument. Answer: this is an answer". + + :param reference_pattern: + The regular expression pattern used for parsing the document references. + If not specified, no parsing is done, and all documents are referenced. + References need to be specified as indices of the input documents and start at [1]. + Example: `\\[(\\d+)\\]` finds "1" in a string "this is an answer[1]". + + :param last_message_only: + If False (default value), all messages are used as the answer. + If True, only the last message is used as the answer. + """ + if pattern: + AnswerBuilder._check_num_groups_in_regex(pattern) + + self.pattern = pattern + self.reference_pattern = reference_pattern + self.last_message_only = last_message_only + + @component.output_types(answers=List[GeneratedAnswer]) + def run( # pylint: disable=too-many-positional-arguments + self, + query: str, + replies: Union[List[str], List[ChatMessage]], + meta: Optional[List[Dict[str, Any]]] = None, + documents: Optional[List[Document]] = None, + pattern: Optional[str] = None, + reference_pattern: Optional[str] = None, + ): + """ + Turns the output of a Generator into `GeneratedAnswer` objects using regular expressions. + + :param query: + The input query used as the Generator prompt. + :param replies: + The output of the Generator. Can be a list of strings or a list of `ChatMessage` objects. + :param meta: + The metadata returned by the Generator. If not specified, the generated answer will contain no metadata. + :param documents: + The documents used as the Generator inputs. If specified, they are added to + the`GeneratedAnswer` objects. + If both `documents` and `reference_pattern` are specified, the documents referenced in the + Generator output are extracted from the input documents and added to the `GeneratedAnswer` objects. + :param pattern: + The regular expression pattern to extract the answer text from the Generator. + If not specified, the entire response is used as the answer. + The regular expression can have one capture group at most. + If present, the capture group text + is used as the answer. If no capture group is present, the whole match is used as the answer. + Examples: + `[^\\n]+$` finds "this is an answer" in a string "this is an argument.\\nthis is an answer". + `Answer: (.*)` finds "this is an answer" in a string + "this is an argument. Answer: this is an answer". + :param reference_pattern: + The regular expression pattern used for parsing the document references. + If not specified, no parsing is done, and all documents are referenced. + References need to be specified as indices of the input documents and start at [1]. + Example: `\\[(\\d+)\\]` finds "1" in a string "this is an answer[1]". + + :returns: A dictionary with the following keys: + - `answers`: The answers received from the output of the Generator. + """ + if not meta: + meta = [{}] * len(replies) + elif len(replies) != len(meta): + raise ValueError(f"Number of replies ({len(replies)}), and metadata ({len(meta)}) must match.") + + if pattern: + AnswerBuilder._check_num_groups_in_regex(pattern) + + pattern = pattern or self.pattern + reference_pattern = reference_pattern or self.reference_pattern + all_answers = [] + + replies_to_iterate = replies + meta_to_iterate = meta + + if self.last_message_only and replies: + replies_to_iterate = replies[-1:] + meta_to_iterate = meta[-1:] + + for reply, given_metadata in zip(replies_to_iterate, meta_to_iterate): + # Extract content from ChatMessage objects if reply is a ChatMessages, else use the string as is + if isinstance(reply, ChatMessage): + extracted_reply = reply.text or "" + else: + extracted_reply = str(reply) + extracted_metadata = reply.meta if isinstance(reply, ChatMessage) else {} + + extracted_metadata = {**extracted_metadata, **given_metadata} + extracted_metadata["all_messages"] = replies + + referenced_docs = [] + if documents: + if reference_pattern: + reference_idxs = AnswerBuilder._extract_reference_idxs(extracted_reply, reference_pattern) + else: + reference_idxs = [doc_idx for doc_idx, _ in enumerate(documents)] + + for idx in reference_idxs: + try: + referenced_docs.append(documents[idx]) + except IndexError: + logger.warning( + "Document index '{index}' referenced in Generator output is out of range. ", index=idx + 1 + ) + + answer_string = AnswerBuilder._extract_answer_string(extracted_reply, pattern) + answer = GeneratedAnswer( + data=answer_string, query=query, documents=referenced_docs, meta=extracted_metadata + ) + all_answers.append(answer) + + return {"answers": all_answers} + + @staticmethod + def _extract_answer_string(reply: str, pattern: Optional[str] = None) -> str: + """ + Extract the answer string from the generator output using the specified pattern. + + If no pattern is specified, the whole string is used as the answer. + + :param reply: + The output of the Generator. A string. + :param pattern: + The regular expression pattern to use to extract the answer text from the generator output. + """ + if pattern is None: + return reply + + if match := re.search(pattern, reply): + # No capture group in pattern -> use the whole match as answer + if not match.lastindex: + return match.group(0) + # One capture group in pattern -> use the capture group as answer + return match.group(1) + return "" + + @staticmethod + def _extract_reference_idxs(reply: str, reference_pattern: str) -> List[int]: + document_idxs = re.findall(reference_pattern, reply) + return [int(idx) - 1 for idx in document_idxs] + + @staticmethod + def _check_num_groups_in_regex(pattern: str): + num_groups = re.compile(pattern).groups + if num_groups > 1: + raise ValueError( + f"Pattern '{pattern}' contains multiple capture groups. " + f"Please specify a pattern with at most one capture group." + ) diff --git a/test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py b/test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py index a914cdf8..a80b4580 100644 --- a/test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py +++ b/test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py @@ -2,7 +2,7 @@ import pytest -from haystack.components.builders.answer_builder import AnswerBuilder +from haystack_experimental.components.builders.answer_builder import AnswerBuilder from haystack.components.generators.chat import OpenAIChatGenerator from haystack.components.joiners import AnswerJoiner from haystack.core.pipeline import Pipeline diff --git a/test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py b/test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py index 175e1749..8ebc47c8 100644 --- a/test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py +++ b/test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py @@ -2,7 +2,7 @@ import pytest -from haystack.components.builders.answer_builder import AnswerBuilder +from haystack_experimental.components.builders.answer_builder import AnswerBuilder from haystack.components.builders.prompt_builder import PromptBuilder from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder from haystack.components.generators import OpenAIGenerator From 35414dceb345587962f029e1ee9484a2d4e29a43 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Tue, 10 Jun 2025 16:37:59 +0200 Subject: [PATCH 09/30] Typing fix --- haystack_experimental/components/builders/answer_builder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/haystack_experimental/components/builders/answer_builder.py b/haystack_experimental/components/builders/answer_builder.py index c96f117c..eaf1b05c 100644 --- a/haystack_experimental/components/builders/answer_builder.py +++ b/haystack_experimental/components/builders/answer_builder.py @@ -75,7 +75,7 @@ def run( # pylint: disable=too-many-positional-arguments documents: Optional[List[Document]] = None, pattern: Optional[str] = None, reference_pattern: Optional[str] = None, - ): + ) -> Dict[str, List[GeneratedAnswer]]: """ Turns the output of a Generator into `GeneratedAnswer` objects using regular expressions. @@ -191,7 +191,7 @@ def _extract_reference_idxs(reply: str, reference_pattern: str) -> List[int]: return [int(idx) - 1 for idx in document_idxs] @staticmethod - def _check_num_groups_in_regex(pattern: str): + def _check_num_groups_in_regex(pattern: str) -> None: num_groups = re.compile(pattern).groups if num_groups > 1: raise ValueError( From 8e826129e46fec2a87dff111ca4a56493941746f Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Wed, 11 Jun 2025 13:13:29 +0200 Subject: [PATCH 10/30] PR comments --- .../components/builders/answer_builder.py | 73 +--------------- haystack_experimental/dataclasses/__init__.py | 4 +- haystack_experimental/dataclasses/answer.py | 87 +------------------ ...test_pipeline_breakpoints_answer_joiner.py | 2 +- .../test_pipeline_breakpoints_rag_hybrid.py | 2 +- 5 files changed, 10 insertions(+), 158 deletions(-) diff --git a/haystack_experimental/components/builders/answer_builder.py b/haystack_experimental/components/builders/answer_builder.py index eaf1b05c..06ea9deb 100644 --- a/haystack_experimental/components/builders/answer_builder.py +++ b/haystack_experimental/components/builders/answer_builder.py @@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Union from haystack import Document, component, logging +from haystack.components.builders.answer_builder import AnswerBuilder as HaystackAnswerBuilder from haystack.dataclasses.chat_message import ChatMessage from haystack_experimental.dataclasses import GeneratedAnswer @@ -14,7 +15,7 @@ @component -class AnswerBuilder: +class AnswerBuilder(HaystackAnswerBuilder): """ Converts a query and Generator replies into a `GeneratedAnswer` object. @@ -33,39 +34,6 @@ class AnswerBuilder: ``` """ - def __init__( - self, pattern: Optional[str] = None, reference_pattern: Optional[str] = None, last_message_only: bool = False - ): - """ - Creates an instance of the AnswerBuilder component. - - :param pattern: - The regular expression pattern to extract the answer text from the Generator. - If not specified, the entire response is used as the answer. - The regular expression can have one capture group at most. - If present, the capture group text - is used as the answer. If no capture group is present, the whole match is used as the answer. - Examples: - `[^\\n]+$` finds "this is an answer" in a string "this is an argument.\\nthis is an answer". - `Answer: (.*)` finds "this is an answer" in a string "this is an argument. Answer: this is an answer". - - :param reference_pattern: - The regular expression pattern used for parsing the document references. - If not specified, no parsing is done, and all documents are referenced. - References need to be specified as indices of the input documents and start at [1]. - Example: `\\[(\\d+)\\]` finds "1" in a string "this is an answer[1]". - - :param last_message_only: - If False (default value), all messages are used as the answer. - If True, only the last message is used as the answer. - """ - if pattern: - AnswerBuilder._check_num_groups_in_regex(pattern) - - self.pattern = pattern - self.reference_pattern = reference_pattern - self.last_message_only = last_message_only - @component.output_types(answers=List[GeneratedAnswer]) def run( # pylint: disable=too-many-positional-arguments self, @@ -161,40 +129,3 @@ def run( # pylint: disable=too-many-positional-arguments all_answers.append(answer) return {"answers": all_answers} - - @staticmethod - def _extract_answer_string(reply: str, pattern: Optional[str] = None) -> str: - """ - Extract the answer string from the generator output using the specified pattern. - - If no pattern is specified, the whole string is used as the answer. - - :param reply: - The output of the Generator. A string. - :param pattern: - The regular expression pattern to use to extract the answer text from the generator output. - """ - if pattern is None: - return reply - - if match := re.search(pattern, reply): - # No capture group in pattern -> use the whole match as answer - if not match.lastindex: - return match.group(0) - # One capture group in pattern -> use the capture group as answer - return match.group(1) - return "" - - @staticmethod - def _extract_reference_idxs(reply: str, reference_pattern: str) -> List[int]: - document_idxs = re.findall(reference_pattern, reply) - return [int(idx) - 1 for idx in document_idxs] - - @staticmethod - def _check_num_groups_in_regex(pattern: str) -> None: - num_groups = re.compile(pattern).groups - if num_groups > 1: - raise ValueError( - f"Pattern '{pattern}' contains multiple capture groups. " - f"Please specify a pattern with at most one capture group." - ) diff --git a/haystack_experimental/dataclasses/__init__.py b/haystack_experimental/dataclasses/__init__.py index 5af5fc85..bc7446d4 100644 --- a/haystack_experimental/dataclasses/__init__.py +++ b/haystack_experimental/dataclasses/__init__.py @@ -10,11 +10,11 @@ _import_structure = { "chat_message": ["ChatMessage"], "image_content": ["ImageContent"], - "answer": ["Answer", "ExtractedAnswer", "GeneratedAnswer"], + "answer": ["GeneratedAnswer"], } if TYPE_CHECKING: - from .answer import Answer, ExtractedAnswer, GeneratedAnswer + from .answer import GeneratedAnswer from .chat_message import ChatMessage from .image_content import ImageContent else: diff --git a/haystack_experimental/dataclasses/answer.py b/haystack_experimental/dataclasses/answer.py index 57d5cd5c..eaed7c49 100644 --- a/haystack_experimental/dataclasses/answer.py +++ b/haystack_experimental/dataclasses/answer.py @@ -2,95 +2,16 @@ # # SPDX-License-Identifier: Apache-2.0 -from dataclasses import asdict, dataclass, field -from typing import Any, Dict, List, Optional, Protocol, runtime_checkable +from dataclasses import dataclass +from typing import Any, Dict from haystack.core.serialization import default_from_dict, default_to_dict from haystack.dataclasses import ChatMessage, Document +from haystack.dataclasses import GeneratedAnswer as HaystackGeneratedAnswer -@runtime_checkable @dataclass -class Answer(Protocol): - data: Any - query: str - meta: Dict[str, Any] - - def to_dict(self) -> Dict[str, Any]: # noqa: D102 - ... - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "Answer": # noqa: D102 - ... - - -@dataclass -class ExtractedAnswer: - query: str - score: float - data: Optional[str] = None - document: Optional[Document] = None - context: Optional[str] = None - document_offset: Optional["Span"] = None - context_offset: Optional["Span"] = None - meta: Dict[str, Any] = field(default_factory=dict) - - @dataclass - class Span: - start: int - end: int - - def to_dict(self) -> Dict[str, Any]: - """ - Serialize the object to a dictionary. - - :returns: - Serialized dictionary representation of the object. - """ - document = self.document.to_dict(flatten=False) if self.document is not None else None - document_offset = asdict(self.document_offset) if self.document_offset is not None else None - context_offset = asdict(self.context_offset) if self.context_offset is not None else None - return default_to_dict( - self, - data=self.data, - query=self.query, - document=document, - context=self.context, - score=self.score, - document_offset=document_offset, - context_offset=context_offset, - meta=self.meta, - ) - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ExtractedAnswer": - """ - Deserialize the object from a dictionary. - - :param data: - Dictionary representation of the object. - :returns: - Deserialized object. - """ - init_params = data.get("init_parameters", {}) - if (doc := init_params.get("document")) is not None: - data["init_parameters"]["document"] = Document.from_dict(doc) - - if (offset := init_params.get("document_offset")) is not None: - data["init_parameters"]["document_offset"] = ExtractedAnswer.Span(**offset) - - if (offset := init_params.get("context_offset")) is not None: - data["init_parameters"]["context_offset"] = ExtractedAnswer.Span(**offset) - return default_from_dict(cls, data) - - -@dataclass -class GeneratedAnswer: - data: str - query: str - documents: List[Document] - meta: Dict[str, Any] = field(default_factory=dict) - +class GeneratedAnswer(HaystackGeneratedAnswer): def to_dict(self) -> Dict[str, Any]: """ Serialize the object to a dictionary. diff --git a/test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py b/test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py index a80b4580..fbf65f27 100644 --- a/test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py +++ b/test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py @@ -66,7 +66,7 @@ def answer_join_pipeline(self, mock_openai_chat_generator): Creates a pipeline with mocked OpenAI components. """ # Create the pipeline with mocked components - pipeline = Pipeline() + pipeline = Pipeline(connection_type_validation=False) pipeline.add_component("gpt-4o", mock_openai_chat_generator("gpt-4o")) pipeline.add_component("gpt-3", mock_openai_chat_generator("gpt-3.5-turbo")) pipeline.add_component("answer_builder_a", AnswerBuilder()) diff --git a/test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py b/test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py index 8ebc47c8..f0570063 100644 --- a/test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py +++ b/test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py @@ -208,7 +208,7 @@ def hybrid_rag_pipeline(self, document_store, mock_transformers_similarity_ranke \nQuestion: {{question}} \nAnswer: """ - pipeline = Pipeline() + pipeline = Pipeline(connection_type_validation=False) pipeline.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="bm25_retriever") # Use the mocked embedder instead of creating a new one From 0d54ed88d7fac47dabda7b19bfef2abfd3a237af Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Wed, 11 Jun 2025 13:14:46 +0200 Subject: [PATCH 11/30] Small fix to PR --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 2fc55313..81cfc45a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ fmt-check = "ruff check {args} && ruff format --check {args}" [tool.hatch.envs.test] extra-dependencies = [ "colorama", # Pipeline checkpoints experiment - "transformers[torch, sentencepiece]>=4.51.1,<4.52", # Pipeline checkpoints experiment + "transformers[torch,sentencepiece]>=4.51.1,<4.52", # Pipeline checkpoints experiment "arrow>=1.3.0", # Multimodal experiment - ChatPromptBuilder "pypdfium2", # Multimodal experiment - PDFToImageContent "pillow", # Multimodal experiment - ImageFileToImageContent, PDFToImageContent From 8ecb77dbdef131552f56a132df37c890fac7ee31 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Wed, 11 Jun 2025 13:20:58 +0200 Subject: [PATCH 12/30] Fix linting --- haystack_experimental/components/builders/answer_builder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/haystack_experimental/components/builders/answer_builder.py b/haystack_experimental/components/builders/answer_builder.py index 06ea9deb..787e9393 100644 --- a/haystack_experimental/components/builders/answer_builder.py +++ b/haystack_experimental/components/builders/answer_builder.py @@ -2,7 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 -import re from typing import Any, Dict, List, Optional, Union from haystack import Document, component, logging From e4985da4b9f29f57b58cc7d84e5a386b971e12f2 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Wed, 11 Jun 2025 14:45:36 +0200 Subject: [PATCH 13/30] FIx for consume inputs --- haystack_experimental/core/pipeline/base.py | 10 +++++-- .../core/pipeline/pipeline.py | 28 ++++++++++++++----- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/haystack_experimental/core/pipeline/base.py b/haystack_experimental/core/pipeline/base.py index df1423f1..a832cc13 100644 --- a/haystack_experimental/core/pipeline/base.py +++ b/haystack_experimental/core/pipeline/base.py @@ -973,7 +973,9 @@ def _convert_to_internal_format(pipeline_inputs: Dict[str, Any]) -> Dict[str, Di return inputs @staticmethod - def _consume_component_inputs(component_name: str, component: Dict, inputs: Dict) -> Dict[str, Any]: + def _consume_component_inputs( + component_name: str, component: Dict, inputs: Dict, is_resume: bool = False + ) -> Dict[str, Any]: """ Extracts the inputs needed to run for the component and removes them from the global inputs state. @@ -988,7 +990,11 @@ def _consume_component_inputs(component_name: str, component: Dict, inputs: Dict for socket_name, socket in component["input_sockets"].items(): socket_inputs = component_inputs.get(socket_name, []) socket_inputs = [sock["value"] for sock in socket_inputs if sock["value"] is not _NO_OUTPUT_PRODUCED] - if socket_inputs: + + # if we are resuming a component, the inputs are already consumed, so we just return the first input + if is_resume: + consumed_inputs[socket_name] = socket_inputs[0] + elif socket_inputs: if not socket.is_variadic: # We only care about the first input provided to the socket. consumed_inputs[socket_name] = socket_inputs[0] diff --git a/haystack_experimental/core/pipeline/pipeline.py b/haystack_experimental/core/pipeline/pipeline.py index 5b0bdf0f..3ca391cf 100644 --- a/haystack_experimental/core/pipeline/pipeline.py +++ b/haystack_experimental/core/pipeline/pipeline.py @@ -277,9 +277,15 @@ def run( # noqa: PLR0915, PLR0912 component_name, component_visits[component_name] ) - component_inputs = self._consume_component_inputs( - component_name=component_name, component=component, inputs=inputs - ) + if self.resume_state and self.resume_state["breakpoint"]["component"] == component_name: + component_inputs = self._consume_component_inputs( + component_name=component_name, component=component, inputs=inputs, is_resume=True + ) + else: + component_inputs = self._consume_component_inputs( + component_name=component_name, component=component, inputs=inputs + ) + # We need to add missing defaults using default values from input sockets because the run signature # might not provide these defaults for components with inputs defined dynamically upon component # initialization @@ -304,6 +310,12 @@ def run( # noqa: PLR0915, PLR0912 if not key.startswith("__") } # type: ignore[assignment] + if "_template_string" in state_inputs_serialised[component_name]["init_parameters"]: + state_inputs_serialised[component_name]["init_parameters"]["template"] = ( + state_inputs_serialised[component_name]["init_parameters"]["_template_string"] + ) + state_inputs_serialised[component_name]["init_parameters"].pop("_template_string") + Pipeline._check_breakpoints( breakpoints=validated_breakpoints, component_name=component_name, @@ -316,10 +328,12 @@ def run( # noqa: PLR0915, PLR0912 # the _consume_component_inputs() creates a 3 level deep list for lazy variadic component # we need to flatten it to 2 levels only - if self.resume_state and self.resume_state["breakpoint"]["component"] == component_name: - for socket_name, socket in component["input_sockets"].items(): - if is_socket_lazy_variadic(socket): - component_inputs[socket_name] = component_inputs[socket_name][0] + # if self.resume_state and self.resume_state["breakpoint"]["component"] == component_name: + # for socket_name, socket in component["input_sockets"].items(): + # if is_socket_lazy_variadic(socket): + # print ("I am a lazy variadic component") + # print (component_inputs[socket_name]) + # component_inputs[socket_name] = component_inputs[socket_name][0] component_outputs = self._run_component( component_name=component_name, From 83e4188b7bdb12841180200c01cdd497864c7a04 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Wed, 11 Jun 2025 14:47:54 +0200 Subject: [PATCH 14/30] Fix linting --- haystack_experimental/core/pipeline/pipeline.py | 1 - 1 file changed, 1 deletion(-) diff --git a/haystack_experimental/core/pipeline/pipeline.py b/haystack_experimental/core/pipeline/pipeline.py index 3ca391cf..c0e38267 100644 --- a/haystack_experimental/core/pipeline/pipeline.py +++ b/haystack_experimental/core/pipeline/pipeline.py @@ -13,7 +13,6 @@ from haystack import Answer, Document, ExtractedAnswer, logging, tracing from haystack.core.component import Component -from haystack.core.pipeline.component_checks import is_socket_lazy_variadic from haystack.dataclasses import ChatMessage, SparseEmbedding from haystack.telemetry import pipeline_running From a08bb16be5f1290b3354e4a40c2fd5f1e780e9e5 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Wed, 11 Jun 2025 15:18:39 +0200 Subject: [PATCH 15/30] Fix typing --- haystack_experimental/core/pipeline/pipeline.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/haystack_experimental/core/pipeline/pipeline.py b/haystack_experimental/core/pipeline/pipeline.py index c0e38267..69939e32 100644 --- a/haystack_experimental/core/pipeline/pipeline.py +++ b/haystack_experimental/core/pipeline/pipeline.py @@ -303,17 +303,17 @@ def run( # noqa: PLR0915, PLR0912 state_inputs_serialised[component_name] = deepcopy(component_inputs) # we use dict instead of to_dict() because it strips away class types of init params - state_inputs_serialised[component_name]["init_parameters"] = { + init_params = { key: value for key, value in component["instance"].__dict__.items() if not key.startswith("__") - } # type: ignore[assignment] + } - if "_template_string" in state_inputs_serialised[component_name]["init_parameters"]: - state_inputs_serialised[component_name]["init_parameters"]["template"] = ( - state_inputs_serialised[component_name]["init_parameters"]["_template_string"] - ) - state_inputs_serialised[component_name]["init_parameters"].pop("_template_string") + if "_template_string" in init_params: + init_params["template"] = init_params["_template_string"] + init_params.pop("_template_string") + + state_inputs_serialised[component_name]["init_parameters"] = init_params # type: ignore[assignment] Pipeline._check_breakpoints( breakpoints=validated_breakpoints, From cd90b6376a6a13f8d01a759f6b936eb61fb6a56e Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Wed, 11 Jun 2025 18:41:21 +0200 Subject: [PATCH 16/30] Remove check_breakpoint --- .../core/pipeline/pipeline.py | 254 ++++++++---------- test/core/pipeline/test_pipeline.py | 44 +-- ...test_pipeline_breakpoints_answer_joiner.py | 2 +- ...test_pipeline_breakpoints_branch_joiner.py | 2 +- .../test_pipeline_breakpoints_list_joiner.py | 2 +- .../test_pipeline_breakpoints_loops.py | 2 +- .../test_pipeline_breakpoints_rag_hybrid.py | 2 +- ...test_pipeline_breakpoints_string_joiner.py | 2 +- 8 files changed, 135 insertions(+), 175 deletions(-) diff --git a/haystack_experimental/core/pipeline/pipeline.py b/haystack_experimental/core/pipeline/pipeline.py index 69939e32..58a1907c 100644 --- a/haystack_experimental/core/pipeline/pipeline.py +++ b/haystack_experimental/core/pipeline/pipeline.py @@ -92,7 +92,7 @@ def run( # noqa: PLR0915, PLR0912 self, data: Dict[str, Any], include_outputs_from: Optional[Set[str]] = None, - breakpoints: Optional[Set[Tuple[str, Optional[int]]]] = None, + pipeline_breakpoint: Optional[Tuple[str, Optional[int]]] = None, resume_state: Optional[Dict[str, Any]] = None, debug_path: Optional[Union[str, Path]] = None, ) -> Dict[str, Any]: @@ -171,8 +171,8 @@ def run( # noqa: PLR0915, PLR0912 invoked multiple times (in a loop), only the last-produced output is included. - :param breakpoints: - Set of tuples of component names and visit counts at which the pipeline should break execution. + :param pipeline_breakpoint: + Tuple of component name and visit count at which the pipeline should break execution. If the visit count is not given, it is assumed to be 0, it will break on the first visit. :param resume_state: @@ -196,19 +196,19 @@ def run( # noqa: PLR0915, PLR0912 :raises PipelineMaxComponentRuns: If a Component reaches the maximum number of times it can be run in this Pipeline. :raises PipelineBreakpointException: - When a breakpoint is triggered. Contains the component name, state, and partial results. + When a pipeline_breakpoint is triggered. Contains the component name, state, and partial results. """ pipeline_running(self) - if breakpoints and resume_state: + if pipeline_breakpoint and resume_state: logger.warning( - "Breakpoints cannot be provided when resuming a pipeline. All breakpoints will be ignored.", + "pipeline_breakpoint will be ignored because it cannot be provided when resuming a pipeline.", ) self.debug_path = debug_path self.resume_state = resume_state - # make sure breakpoints are valid and have a default visit count - validated_breakpoints = self._validate_breakpoints(breakpoints) if breakpoints else None + # make sure pipeline_breakpoint is valid and have a default visit count + validated_breakpoint = self._validate_breakpoint(pipeline_breakpoint) if pipeline_breakpoint else None # TODO: Remove this warmup once we can check reliably whether a component has been warmed up or not # As of now it's here to make sure we don't have failing tests that assume warm_up() is called in run() @@ -256,49 +256,56 @@ def run( # noqa: PLR0915, PLR0912 # check if pipeline is blocked before execution self.validate_pipeline(priority_queue) - try: - while True: - candidate = self._get_next_runnable_component(priority_queue, component_visits) - if candidate is None: - break - - priority, component_name, component = candidate - - if len(priority_queue) > 0 and priority in [ComponentPriority.DEFER, ComponentPriority.DEFER_LAST]: - component_name, topological_sort = self._tiebreak_waiting_components( - component_name=component_name, - priority=priority, - priority_queue=priority_queue, - topological_sort=cached_topological_sort, - ) - cached_topological_sort = topological_sort - component = self._get_component_with_graph_metadata_and_visits( - component_name, component_visits[component_name] - ) + while True: + candidate = self._get_next_runnable_component(priority_queue, component_visits) + if candidate is None: + break - if self.resume_state and self.resume_state["breakpoint"]["component"] == component_name: - component_inputs = self._consume_component_inputs( - component_name=component_name, component=component, inputs=inputs, is_resume=True - ) - else: - component_inputs = self._consume_component_inputs( - component_name=component_name, component=component, inputs=inputs - ) + priority, component_name, component = candidate + + if len(priority_queue) > 0 and priority in [ComponentPriority.DEFER, ComponentPriority.DEFER_LAST]: + component_name, topological_sort = self._tiebreak_waiting_components( + component_name=component_name, + priority=priority, + priority_queue=priority_queue, + topological_sort=cached_topological_sort, + ) + cached_topological_sort = topological_sort + component = self._get_component_with_graph_metadata_and_visits( + component_name, component_visits[component_name] + ) + + is_resume = bool( + self.resume_state and self.resume_state["pipeline_breakpoint"]["component"] == component_name + ) + component_inputs = self._consume_component_inputs( + component_name=component_name, component=component, inputs=inputs, is_resume=is_resume + ) - # We need to add missing defaults using default values from input sockets because the run signature - # might not provide these defaults for components with inputs defined dynamically upon component - # initialization - component_inputs = self._add_missing_input_defaults(component_inputs, component["input_sockets"]) + # We need to add missing defaults using default values from input sockets because the run signature + # might not provide these defaults for components with inputs defined dynamically upon component + # initialization + component_inputs = self._add_missing_input_defaults(component_inputs, component["input_sockets"]) - # Deserialize the component_inputs if they are passed in resume state - # this check will prevent other component_inputs generated at runtime from being deserialized - if self.resume_state and component_name in self.resume_state["pipeline_state"]["inputs"].keys(): - for key, value in component_inputs.items(): - component_inputs[key] = _deserialize_component_input(value) + # Scenario 1: Resume state is provided to resume the pipeline at a specific component - if validated_breakpoints and not self.resume_state: + # Deserialize the component_inputs if they are passed in resume state + # this check will prevent other component_inputs generated at runtime from being deserialized + if self.resume_state and component_name in self.resume_state["pipeline_state"]["inputs"].keys(): + for key, value in component_inputs.items(): + component_inputs[key] = _deserialize_component_input(value) + + # Scenario 2: pipeline_breakpoint is provided to stop the pipeline at + # a specific component and visit count + + if validated_breakpoint: + breakpoint_component, visit_count = validated_breakpoint + breakpoint_triggered = bool( + breakpoint_component == component_name and visit_count == component_visits[component_name] + ) + if breakpoint_triggered: state_inputs_serialised = deepcopy(inputs) - # we store the init params for the component with breakpoint in debug state + # we store the init params for the component with pipeline_breakpoint in debug state # this is helpful for retaining the state of the component and manual debugging state_inputs_serialised[component_name] = deepcopy(component_inputs) @@ -315,57 +322,50 @@ def run( # noqa: PLR0915, PLR0912 state_inputs_serialised[component_name]["init_parameters"] = init_params # type: ignore[assignment] - Pipeline._check_breakpoints( - breakpoints=validated_breakpoints, - component_name=component_name, - component_visits=component_visits, + Pipeline._save_state( inputs=state_inputs_serialised, + component_name=str(component_name), + component_visits=component_visits, debug_path=self.debug_path, original_input_data=data, ordered_component_names=self.ordered_component_names, ) + msg = f"Breaking at component {component_name} visit count {component_visits[component_name]}" + logger.info(msg) + raise PipelineBreakpointException( + message=msg, + component=component_name, + state=state_inputs_serialised, + results=pipeline_outputs, + ) - # the _consume_component_inputs() creates a 3 level deep list for lazy variadic component - # we need to flatten it to 2 levels only - # if self.resume_state and self.resume_state["breakpoint"]["component"] == component_name: - # for socket_name, socket in component["input_sockets"].items(): - # if is_socket_lazy_variadic(socket): - # print ("I am a lazy variadic component") - # print (component_inputs[socket_name]) - # component_inputs[socket_name] = component_inputs[socket_name][0] - - component_outputs = self._run_component( - component_name=component_name, - component=component, - inputs=component_inputs, # the inputs to the current component - component_visits=component_visits, - parent_span=span, - ) - - # Updates global input state with component outputs and returns outputs that should go to - # pipeline outputs. - component_pipeline_outputs = self._write_component_outputs( - component_name=component_name, - component_outputs=component_outputs, - inputs=inputs, - receivers=cached_receivers[component_name], - include_outputs_from=include_outputs_from, - ) + component_outputs = self._run_component( + component_name=component_name, + component=component, + inputs=component_inputs, # the inputs to the current component + component_visits=component_visits, + parent_span=span, + ) - if component_pipeline_outputs: - pipeline_outputs[component_name] = deepcopy(component_pipeline_outputs) - if self._is_queue_stale(priority_queue): - priority_queue = self._fill_queue(self.ordered_component_names, inputs, component_visits) + # Updates global input state with component outputs and returns outputs that should go to + # pipeline outputs. + component_pipeline_outputs = self._write_component_outputs( + component_name=component_name, + component_outputs=component_outputs, + inputs=inputs, + receivers=cached_receivers[component_name], + include_outputs_from=include_outputs_from, + ) - except PipelineBreakpointException as e: - # Add the current pipeline results to the exception - e.results = pipeline_outputs - raise + if component_pipeline_outputs: + pipeline_outputs[component_name] = deepcopy(component_pipeline_outputs) + if self._is_queue_stale(priority_queue): + priority_queue = self._fill_queue(self.ordered_component_names, inputs, component_visits) - if breakpoints: - logger.warning(f"Given breakpoint {breakpoints} was never triggered. This is because:") + if pipeline_breakpoint: + logger.warning(f"Given pipeline_breakpoint {pipeline_breakpoint} was never triggered. This is because:") logger.warning("1. The provided component is not a part of the pipeline execution path.") - logger.warning("2. The component did not reach the visit count specified in the breakpoint") + logger.warning("2. The component did not reach the visit count specified in the pipeline_breakpoint") return pipeline_outputs def inject_resume_state_into_graph( @@ -385,32 +385,32 @@ def inject_resume_state_into_graph( component_visits = self.resume_state["pipeline_state"]["component_visits"] self.ordered_component_names = self.resume_state["pipeline_state"]["ordered_component_names"] msg = ( - f"Resuming pipeline from {self.resume_state['breakpoint']['component']} " - f"visit count {self.resume_state['breakpoint']['visits']}" + f"Resuming pipeline from {self.resume_state['pipeline_breakpoint']['component']} " + f"visit count {self.resume_state['pipeline_breakpoint']['visits']}" ) logger.info(msg) return component_visits, data - def _validate_breakpoints(self, breakpoints: Set[Tuple[str, Optional[int]]]) -> Set[Tuple[str, int]]: + def _validate_breakpoint(self, pipeline_breakpoint: Tuple[str, Optional[int]]) -> Tuple[str, int]: """ - Validates the breakpoints passed to the pipeline. + Validates the pipeline_breakpoint passed to the pipeline. Make sure they are all valid components registered in the pipeline, If the visit is not given, it is assumed to be 0, it will break on the first visit. - :param breakpoints: Set of tuples of component names and visit counts at which the pipeline should stop. + :param pipeline_breakpoint: Tuple of component name and visit count at which the pipeline should stop. :returns: - Set of valid breakpoints. + Tuple of valid pipeline_breakpoint. """ - processed_breakpoints: Set[Tuple[str, int]] = set() - - for break_point in breakpoints: - if break_point[0] not in self.graph.nodes: - raise ValueError(f"Breakpoint {break_point} is not a registered component in the pipeline") - valid_breakpoint: Tuple[str, int] = (break_point[0], 0 if break_point[1] is None else break_point[1]) - processed_breakpoints.add(valid_breakpoint) - return processed_breakpoints + if pipeline_breakpoint and pipeline_breakpoint[0] not in self.graph.nodes: + raise ValueError(f"pipeline_breakpoint {pipeline_breakpoint} is not a registered component in the pipeline") + valid_breakpoint: Tuple[str, int] = ( + (pipeline_breakpoint[0], 0 if pipeline_breakpoint[1] is None else pipeline_breakpoint[1]) + if pipeline_breakpoint + else None + ) + return valid_breakpoint def _validate_pipeline_state(self, resume_state: Dict[str, Any]) -> None: """ @@ -449,8 +449,8 @@ def _validate_pipeline_state(self, resume_state: Dict[str, Any]) -> None: ) logger.info( - f"Resuming pipeline from component: {resume_state['breakpoint']['component']} " - f"(visit {resume_state['breakpoint']['visits']})" + f"Resuming pipeline from component: {resume_state['pipeline_breakpoint']['component']} " + f"(visit {resume_state['pipeline_breakpoint']['visits']})" ) @staticmethod @@ -458,14 +458,14 @@ def _validate_resume_state(state: Dict[str, Any]) -> None: """ Validates the loaded pipeline state. - Ensures that the state contains required keys: "input_data", "breakpoint", and "pipeline_state". + Ensures that the state contains required keys: "input_data", "pipeline_breakpoint", and "pipeline_state". Raises: ValueError: If required keys are missing or the component sets are inconsistent. """ # top-level state has all required keys - required_top_keys = {"input_data", "breakpoint", "pipeline_state"} + required_top_keys = {"input_data", "pipeline_breakpoint", "pipeline_state"} missing_top = required_top_keys - state.keys() if missing_top: raise ValueError(f"Invalid state file: missing required keys {missing_top}") @@ -520,7 +520,7 @@ def load_state(file_path: Union[str, Path]) -> Dict[str, Any]: return state @staticmethod - def save_state( + def _save_state( inputs: Dict[str, Any], component_name: str, component_visits: Dict[str, int], @@ -555,7 +555,7 @@ def save_state( state = { "input_data": _serialize_component_input(original_input_data), # original input data "timestamp": dt.isoformat(), - "breakpoint": {"component": component_name, "visits": component_visits[component_name]}, + "pipeline_breakpoint": {"component": component_name, "visits": component_visits[component_name]}, "pipeline_state": { "inputs": _serialize_component_input(inputs), # current pipeline state inputs "component_visits": component_visits, @@ -587,46 +587,6 @@ def save_state( logger.error(f"Failed to save pipeline state: {str(e)}") raise - @staticmethod - def _check_breakpoints( - breakpoints: Set[Tuple[str, int]], - component_name: str, - component_visits: Dict[str, int], - inputs: Dict[str, Any], - debug_path: Optional[Union[str, Path]] = None, - original_input_data: Optional[Dict[str, Any]] = None, - ordered_component_names: Optional[List[str]] = None, - ) -> None: - """ - Check if the `component_name` is in the breakpoints and if it should break. - - :param breakpoints: Set of tuples of component names and visit counts at which the pipeline should stop. - :param component_name: Name of the component to check. - :param component_visits: The number of times the component has been visited. - :param inputs: The inputs to the pipeline. - :param debug_path: The file path where debug state is saved. - :param original_input_data: The original input data to the pipeline. - :param ordered_component_names: The ordered component names in the pipeline. - - :raises PipelineBreakpointException: When a breakpoint is triggered, with component state information. - """ - for component, visit_count in breakpoints: - if component != component_name: - continue - - if visit_count == component_visits[component_name]: - msg = f"Breaking at component {component_name} visit count {component_visits[component_name]}" - logger.info(msg) - state = Pipeline.save_state( - inputs=inputs, - component_name=str(component_name), - component_visits=component_visits, - debug_path=debug_path, - original_input_data=original_input_data, - ordered_component_names=ordered_component_names, - ) - raise PipelineBreakpointException(msg, component=component_name, state=state) - def _deserialize_component_input(value: Any) -> Any: # noqa: PLR0911 """ diff --git a/test/core/pipeline/test_pipeline.py b/test/core/pipeline/test_pipeline.py index 91cfb6b7..be88e0f9 100644 --- a/test/core/pipeline/test_pipeline.py +++ b/test/core/pipeline/test_pipeline.py @@ -114,7 +114,7 @@ def test_run(self): _ = pp.run({"value": "test_value"}) - def test_validate_breakpoints(self): + def test_validate_breakpoint(self): # simple pipeline joiner_1 = BranchJoiner(type_=str) joiner_2 = BranchJoiner(type_=str) @@ -124,29 +124,29 @@ def test_validate_breakpoints(self): pipeline.connect("comp1", "comp2") # valid breakpoints - breakpoints = {("comp1", 0), ("comp2", 1)} - validated = pipeline._validate_breakpoints(breakpoints) - assert validated == {("comp1", 0), ("comp2", 1)} + breakpoints = ("comp1", 0) + validated = pipeline._validate_breakpoint(breakpoints) + assert validated == ("comp1", 0) # should default to 0 - breakpoints = {("comp1", None), ("comp2", 1)} - validated = pipeline._validate_breakpoints(breakpoints) - assert validated == {("comp1", 0), ("comp2", 1)} + breakpoints = ("comp1", None) + validated = pipeline._validate_breakpoint(breakpoints) + assert validated == ("comp1", 0) # should remain as it is - breakpoints = {("comp1", -1)} - validated = pipeline._validate_breakpoints(breakpoints) - assert validated == {("comp1", -1)} + breakpoints = ("comp1", -1) + validated = pipeline._validate_breakpoint(breakpoints) + assert validated == ("comp1", -1) # contains invalid components - breakpoints = {("comp1", 0), ("non_existent_component", 1)} - with pytest.raises(ValueError, match="Breakpoint .* is not a registered component"): - pipeline._validate_breakpoints(breakpoints) + breakpoints = ("comp3", 0) + with pytest.raises(ValueError, match="pipeline_breakpoint .* is not a registered component"): + pipeline._validate_breakpoint(breakpoints) # no breakpoints are defined - breakpoints = set() - validated = pipeline._validate_breakpoints(breakpoints) - assert validated == set() + breakpoint = None + validated = pipeline._validate_breakpoint(breakpoint) + assert validated is None def test_transform_json_structure_unwraps_sender_value(): @@ -307,7 +307,7 @@ def test_deserialize_component_input_handles_empty_structures(): def test_validate_resume_state_validates_required_keys(): state = { "input_data": {}, - "breakpoint": {"component": "comp1", "visits": 0} + "pipeline_breakpoint": {"component": "comp1", "visits": 0} # Missing pipeline_state } @@ -316,7 +316,7 @@ def test_validate_resume_state_validates_required_keys(): state = { "input_data": {}, - "breakpoint": {"component": "comp1", "visits": 0}, + "pipeline_breakpoint": {"component": "comp1", "visits": 0}, "pipeline_state": { "inputs": {}, "component_visits": {} @@ -330,7 +330,7 @@ def test_validate_resume_state_validates_required_keys(): def test_validate_resume_state_validates_component_consistency(): state = { "input_data": {}, - "breakpoint": {"component": "comp1", "visits": 0}, + "pipeline_breakpoint": {"component": "comp1", "visits": 0}, "pipeline_state": { "inputs": {}, "component_visits": {"comp1": 0, "comp2": 0}, @@ -344,7 +344,7 @@ def test_validate_resume_state_validates_component_consistency(): def test_validate_resume_state_validates_valid_state(): state = { "input_data": {}, - "breakpoint": {"component": "comp1", "visits": 0}, + "pipeline_breakpoint": {"component": "comp1", "visits": 0}, "pipeline_state": { "inputs": {}, "component_visits": {"comp1": 0, "comp2": 0}, @@ -357,7 +357,7 @@ def test_validate_resume_state_validates_valid_state(): def test_load_state_loads_valid_state(tmp_path): state = { "input_data": {}, - "breakpoint": {"component": "comp1", "visits": 0}, + "pipeline_breakpoint": {"component": "comp1", "visits": 0}, "pipeline_state": { "inputs": {}, "component_visits": {"comp1": 0, "comp2": 0}, @@ -374,7 +374,7 @@ def test_load_state_loads_valid_state(tmp_path): def test_load_state_handles_invalid_state(tmp_path): state = { "input_data": {}, - "breakpoint": {"component": "comp1", "visits": 0}, + "pipeline_breakpoint": {"component": "comp1", "visits": 0}, "pipeline_state": { "inputs": {}, "component_visits": {"comp1": 0, "comp2": 0}, diff --git a/test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py b/test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py index fbf65f27..a5a8485e 100644 --- a/test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py +++ b/test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py @@ -109,7 +109,7 @@ def test_pipeline_breakpoints_answer_joiner(self, answer_join_pipeline, output_d } try: - _ = answer_join_pipeline.run(data, breakpoints={(component, 0)}, debug_path=str(output_directory)) + _ = answer_join_pipeline.run(data, pipeline_breakpoint=(component, 0), debug_path=str(output_directory)) except PipelineBreakpointException as e: pass diff --git a/test/core/pipeline/test_pipeline_breakpoints_branch_joiner.py b/test/core/pipeline/test_pipeline_breakpoints_branch_joiner.py index 84cd8b15..5d23c2b6 100644 --- a/test/core/pipeline/test_pipeline_breakpoints_branch_joiner.py +++ b/test/core/pipeline/test_pipeline_breakpoints_branch_joiner.py @@ -106,7 +106,7 @@ def test_pipeline_breakpoints_branch_joiner(self, branch_joiner_pipeline, output } try: - _ = branch_joiner_pipeline.run(data, breakpoints={(component, 0)}, debug_path=str(output_directory)) + _ = branch_joiner_pipeline.run(data, pipeline_breakpoint=(component, 0), debug_path=str(output_directory)) except PipelineBreakpointException as e: pass diff --git a/test/core/pipeline/test_pipeline_breakpoints_list_joiner.py b/test/core/pipeline/test_pipeline_breakpoints_list_joiner.py index c44df13c..cade6bf6 100644 --- a/test/core/pipeline/test_pipeline_breakpoints_list_joiner.py +++ b/test/core/pipeline/test_pipeline_breakpoints_list_joiner.py @@ -120,7 +120,7 @@ def test_list_joiner_pipeline(self, list_joiner_pipeline, output_directory, comp } try: - _ = list_joiner_pipeline.run(data, breakpoints={(component, 0)}, debug_path=str(output_directory)) + _ = list_joiner_pipeline.run(data, pipeline_breakpoint=(component, 0), debug_path=str(output_directory)) except PipelineBreakpointException as e: pass diff --git a/test/core/pipeline/test_pipeline_breakpoints_loops.py b/test/core/pipeline/test_pipeline_breakpoints_loops.py index fce2e735..cb840f9d 100644 --- a/test/core/pipeline/test_pipeline_breakpoints_loops.py +++ b/test/core/pipeline/test_pipeline_breakpoints_loops.py @@ -200,7 +200,7 @@ def test_pipeline_breakpoints_validation_loop(self, validation_loop_pipeline, ou } try: - _ = validation_loop_pipeline.run(data, breakpoints={(component, 0)}, debug_path=str(output_directory)) + _ = validation_loop_pipeline.run(data, pipeline_breakpoint=(component, 0), debug_path=str(output_directory)) except PipelineBreakpointException: pass diff --git a/test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py b/test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py index f0570063..c3996822 100644 --- a/test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py +++ b/test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py @@ -274,7 +274,7 @@ def test_pipeline_breakpoints_hybrid_rag( } try: - _ = hybrid_rag_pipeline.run(data, breakpoints={(component, 0)}, debug_path=str(output_directory)) + _ = hybrid_rag_pipeline.run(data, pipeline_breakpoint=(component, 0), debug_path=str(output_directory)) except PipelineBreakpointException as e: pass diff --git a/test/core/pipeline/test_pipeline_breakpoints_string_joiner.py b/test/core/pipeline/test_pipeline_breakpoints_string_joiner.py index 25957718..5420d53c 100644 --- a/test/core/pipeline/test_pipeline_breakpoints_string_joiner.py +++ b/test/core/pipeline/test_pipeline_breakpoints_string_joiner.py @@ -49,7 +49,7 @@ def test_string_joiner_pipeline(self, string_joiner_pipeline, output_directory, data = {"prompt_builder_1": {"query": string_1}, "prompt_builder_2": {"query": string_2}} try: - _ = string_joiner_pipeline.run(data, breakpoints={(component, 0)}, debug_path=str(output_directory)) + _ = string_joiner_pipeline.run(data, pipeline_breakpoint=(component, 0), debug_path=str(output_directory)) except PipelineBreakpointException as e: pass From a02feb3818e64926ac776bb9d8fd42c649d4c4a5 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Thu, 12 Jun 2025 09:51:58 +0200 Subject: [PATCH 17/30] PR comments --- haystack_experimental/core/pipeline/base.py | 1 + haystack_experimental/core/pipeline/pipeline.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/haystack_experimental/core/pipeline/base.py b/haystack_experimental/core/pipeline/base.py index a832cc13..c4e41b3a 100644 --- a/haystack_experimental/core/pipeline/base.py +++ b/haystack_experimental/core/pipeline/base.py @@ -994,6 +994,7 @@ def _consume_component_inputs( # if we are resuming a component, the inputs are already consumed, so we just return the first input if is_resume: consumed_inputs[socket_name] = socket_inputs[0] + continue elif socket_inputs: if not socket.is_variadic: # We only care about the first input provided to the socket. diff --git a/haystack_experimental/core/pipeline/pipeline.py b/haystack_experimental/core/pipeline/pipeline.py index 58a1907c..7e32188c 100644 --- a/haystack_experimental/core/pipeline/pipeline.py +++ b/haystack_experimental/core/pipeline/pipeline.py @@ -298,7 +298,7 @@ def run( # noqa: PLR0915, PLR0912 # Scenario 2: pipeline_breakpoint is provided to stop the pipeline at # a specific component and visit count - if validated_breakpoint: + if validated_breakpoint is not None: breakpoint_component, visit_count = validated_breakpoint breakpoint_triggered = bool( breakpoint_component == component_name and visit_count == component_visits[component_name] From 3663984b7977e64d28fd52b7d2985dde3d8cde1b Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Thu, 12 Jun 2025 10:08:33 +0200 Subject: [PATCH 18/30] Fix merge --- haystack_experimental/core/pipeline/base.py | 2 +- haystack_experimental/core/pipeline/pipeline.py | 16 ++++------------ 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/haystack_experimental/core/pipeline/base.py b/haystack_experimental/core/pipeline/base.py index cf91c678..8f19f066 100644 --- a/haystack_experimental/core/pipeline/base.py +++ b/haystack_experimental/core/pipeline/base.py @@ -35,7 +35,7 @@ def _consume_component_inputs( if is_resume: consumed_inputs[socket_name] = socket_inputs[0] continue - elif socket_inputs: + if socket_inputs: if not socket.is_variadic: # We only care about the first input provided to the socket. consumed_inputs[socket_name] = socket_inputs[0] diff --git a/haystack_experimental/core/pipeline/pipeline.py b/haystack_experimental/core/pipeline/pipeline.py index f006d14d..f9208457 100644 --- a/haystack_experimental/core/pipeline/pipeline.py +++ b/haystack_experimental/core/pipeline/pipeline.py @@ -12,8 +12,6 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from haystack import Answer, Document, ExtractedAnswer, logging, tracing -from haystack.core.component import Component - from haystack.core.pipeline.base import ComponentPriority from haystack.core.pipeline.pipeline import Pipeline as HaystackPipeline from haystack.dataclasses import ChatMessage, SparseEmbedding @@ -23,10 +21,8 @@ PipelineBreakpointException, PipelineInvalidResumeStateError, ) - -from haystack_experimental.dataclasses import GeneratedAnswer from haystack_experimental.core.pipeline.base import PipelineBase - +from haystack_experimental.dataclasses import GeneratedAnswer logger = logging.getLogger(__name__) @@ -228,9 +224,7 @@ def run( # noqa: PLR0915, PLR0912 component_name, component_visits[component_name] ) - is_resume = bool( - resume_state and resume_state["pipeline_breakpoint"]["component"] == component_name - ) + is_resume = bool(resume_state and resume_state["pipeline_breakpoint"]["component"] == component_name) component_inputs = self._consume_component_inputs( component_name=component_name, component=component, inputs=inputs, is_resume=is_resume ) @@ -292,7 +286,6 @@ def run( # noqa: PLR0915, PLR0912 results=pipeline_outputs, ) - component_outputs = self._run_component( component_name=component_name, component=component, @@ -311,7 +304,6 @@ def run( # noqa: PLR0915, PLR0912 include_outputs_from=include_outputs_from, ) - if component_pipeline_outputs: pipeline_outputs[component_name] = deepcopy(component_pipeline_outputs) if self._is_queue_stale(priority_queue): @@ -338,8 +330,8 @@ def inject_resume_state_into_graph(self, resume_state): component_visits = resume_state["pipeline_state"]["component_visits"] ordered_component_names = resume_state["pipeline_state"]["ordered_component_names"] msg = ( - f"Resuming pipeline from {resume_state['breakpoint']['component']} " - f"visit count {resume_state['breakpoint']['visits']}" + f"Resuming pipeline from {resume_state['pipeline_breakpoint']['component']} " + f"visit count {resume_state['pipeline_breakpoint']['visits']}" ) logger.info(msg) return component_visits, data, resume_state, ordered_component_names From 56e6c5bcddd377c37636393cda420779b0dc1b06 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Thu, 12 Jun 2025 11:01:36 +0200 Subject: [PATCH 19/30] Remove init params --- .../core/pipeline/pipeline.py | 38 +++++-------------- 1 file changed, 9 insertions(+), 29 deletions(-) diff --git a/haystack_experimental/core/pipeline/pipeline.py b/haystack_experimental/core/pipeline/pipeline.py index f9208457..143f756d 100644 --- a/haystack_experimental/core/pipeline/pipeline.py +++ b/haystack_experimental/core/pipeline/pipeline.py @@ -252,23 +252,8 @@ def run( # noqa: PLR0915, PLR0912 ) if breakpoint_triggered: state_inputs_serialised = deepcopy(inputs) - # we store the init params for the component with pipeline_breakpoint in debug state - # this is helpful for retaining the state of the component and manual debugging state_inputs_serialised[component_name] = deepcopy(component_inputs) - # we use dict instead of to_dict() because it strips away class types of init params - init_params = { - key: value - for key, value in component["instance"].__dict__.items() - if not key.startswith("__") - } - - if "_template_string" in init_params: - init_params["template"] = init_params["_template_string"] - init_params.pop("_template_string") - - state_inputs_serialised[component_name]["init_parameters"] = init_params # type: ignore[assignment] - Pipeline._save_state( inputs=state_inputs_serialised, component_name=str(component_name), @@ -475,10 +460,15 @@ def _save_state( ordered_component_names: Optional[List[str]] = None, ) -> Dict[str, Any]: """ - If a debug_path is given it saves the JSON state of the pipeline at a given component visit count in a file. - - If debug_path is not given, it returns the JSON state as a dictionary without saving it to a file. - + Save the pipeline state to a file. + + :param inputs: The current pipeline state inputs. + :param component_name: The name of the component that triggered the breakpoint. + :param component_visits: The visit count of the component that triggered the breakpoint. + :param callback_fun: A function to call with the saved state. + :param debug_path: The path to save the state to. + :param original_input_data: The original input data. + :param ordered_component_names: The ordered component names. :raises: Exception: If the debug_path is not a string or a Path object, or if saving the JSON state fails. @@ -487,16 +477,6 @@ def _save_state( """ dt = datetime.now() - # we store a input params passed during init() in the saved state - # this is helpful for retaining the state of the component and manual debugging - for value in inputs.values(): - if "init_parameters" not in value: - continue - init_params = value.pop("init_parameters") - for k, v in value.items(): - if k in init_params and not v: - value[k] = _serialize_component_input(init_params[k]) - state = { "input_data": _serialize_component_input(original_input_data), # original input data "timestamp": dt.isoformat(), From 9ae69ffef3dfbc57abeba589de60db11ae1d3faf Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Thu, 12 Jun 2025 11:45:19 +0200 Subject: [PATCH 20/30] Create breakpoint.py --- .../core/pipeline/breakpoint.py | 317 ++++++++++++++++++ .../core/pipeline/pipeline.py | 314 +---------------- test/conftest.py | 3 +- test/core/pipeline/test_breakpoint.py | 291 ++++++++++++++++ test/core/pipeline/test_pipeline.py | 278 --------------- ...test_pipeline_breakpoints_branch_joiner.py | 3 +- .../test_pipeline_breakpoints_loops.py | 3 +- 7 files changed, 621 insertions(+), 588 deletions(-) create mode 100644 haystack_experimental/core/pipeline/breakpoint.py create mode 100644 test/core/pipeline/test_breakpoint.py diff --git a/haystack_experimental/core/pipeline/breakpoint.py b/haystack_experimental/core/pipeline/breakpoint.py new file mode 100644 index 00000000..508cecd5 --- /dev/null +++ b/haystack_experimental/core/pipeline/breakpoint.py @@ -0,0 +1,317 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +# pylint: disable=too-many-return-statements, too-many-positional-arguments + +import json +from datetime import datetime +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +from haystack import Answer, Document, ExtractedAnswer, logging +from haystack.dataclasses import ChatMessage, SparseEmbedding + +from haystack_experimental.core.errors import PipelineInvalidResumeStateError +from haystack_experimental.dataclasses import GeneratedAnswer + +logger = logging.getLogger(__name__) + + +def _validate_breakpoint(pipeline_breakpoint: Tuple[str, Optional[int]], graph: Any) -> Tuple[str, int]: + """ + Validates the pipeline_breakpoint passed to the pipeline. + + Make sure they are all valid components registered in the pipeline, + If the visit is not given, it is assumed to be 0, it will break on the first visit. + + :param pipeline_breakpoint: Tuple of component name and visit count at which the pipeline should stop. + :returns: + Tuple of valid pipeline_breakpoint. + """ + + if pipeline_breakpoint and pipeline_breakpoint[0] not in graph.nodes: + raise ValueError(f"pipeline_breakpoint {pipeline_breakpoint} is not a registered component in the pipeline") + valid_breakpoint: Tuple[str, int] = ( + (pipeline_breakpoint[0], 0 if pipeline_breakpoint[1] is None else pipeline_breakpoint[1]) + if pipeline_breakpoint + else None + ) + return valid_breakpoint + + +def _validate_pipeline_state(resume_state: Dict[str, Any], graph: Any) -> None: + """ + Validates that the resume_state contains valid configuration for the current pipeline. + + Raises a PipelineRuntimeError if any component is missing or if the state structure is invalid. + + :param resume_state: The saved state to validate. + """ + + pipeline_state = resume_state["pipeline_state"] + valid_components = set(graph.nodes.keys()) + + # Check if the ordered_component_names are valid components in the pipeline + missing_ordered = set(pipeline_state["ordered_component_names"]) - valid_components + if missing_ordered: + raise PipelineInvalidResumeStateError( + f"Invalid resume state: components {missing_ordered} in 'ordered_component_names' " + f"are not part of the current pipeline." + ) + + # Check if the input_data is valid components in the pipeline + missing_input = set(resume_state["input_data"].keys()) - valid_components + if missing_input: + raise PipelineInvalidResumeStateError( + f"Invalid resume state: components {missing_input} in 'input_data' are not part of the current pipeline." + ) + + # Validate 'component_visits' + missing_visits = set(pipeline_state["component_visits"].keys()) - valid_components + if missing_visits: + raise PipelineInvalidResumeStateError( + f"Invalid resume state: components {missing_visits} in 'component_visits' " + f"are not part of the current pipeline." + ) + + logger.info( + f"Resuming pipeline from component: {resume_state['pipeline_breakpoint']['component']} " + f"(visit {resume_state['pipeline_breakpoint']['visits']})" + ) + + +def _validate_resume_state(state: Dict[str, Any]) -> None: + """ + Validates the loaded pipeline state. + + Ensures that the state contains required keys: "input_data", "pipeline_breakpoint", and "pipeline_state". + + Raises: + ValueError: If required keys are missing or the component sets are inconsistent. + """ + + # top-level state has all required keys + required_top_keys = {"input_data", "pipeline_breakpoint", "pipeline_state"} + missing_top = required_top_keys - state.keys() + if missing_top: + raise ValueError(f"Invalid state file: missing required keys {missing_top}") + + # pipeline_state has the necessary keys + pipeline_state = state["pipeline_state"] + required_pipeline_keys = {"inputs", "component_visits", "ordered_component_names"} + missing_pipeline = required_pipeline_keys - pipeline_state.keys() + if missing_pipeline: + raise ValueError(f"Invalid pipeline_state: missing required keys {missing_pipeline}") + + # component_visits and ordered_component_names must be consistent + components_in_state = set(pipeline_state["component_visits"].keys()) + components_in_order = set(pipeline_state["ordered_component_names"]) + + if components_in_state != components_in_order: + raise ValueError( + f"Inconsistent state: components in pipeline_state['component_visits'] {components_in_state} " + f"do not match components in ordered_component_names {components_in_order}" + ) + + logger.info("Passed resume state validated successfully.") + + +def load_state(file_path: Union[str, Path]) -> Dict[str, Any]: + """ + Load a saved pipeline state. + + :param file_path: Path to the state file + :returns: + Dict containing the loaded state + """ + + file_path = Path(file_path) + + try: + with open(file_path, "r", encoding="utf-8") as f: + state = json.load(f) + except FileNotFoundError: + raise FileNotFoundError(f"File not found: {file_path}") + except json.JSONDecodeError as e: + raise json.JSONDecodeError(f"Invalid JSON file {file_path}: {str(e)}", e.doc, e.pos) + except IOError as e: + raise IOError(f"Error reading {file_path}: {str(e)}") + + try: + _validate_resume_state(state=state) + except ValueError as e: + raise ValueError(f"Invalid pipeline state from {file_path}: {str(e)}") + + logger.info(f"Successfully loaded pipeline state from: {file_path}") + return state + + +def _save_state( + inputs: Dict[str, Any], + component_name: str, + component_visits: Dict[str, int], + callback_fun: Optional[Callable[..., Any]] = None, + debug_path: Optional[Union[str, Path]] = None, + original_input_data: Optional[Dict[str, Any]] = None, + ordered_component_names: Optional[List[str]] = None, +) -> Dict[str, Any]: + """ + Save the pipeline state to a file. + + :param inputs: The current pipeline state inputs. + :param component_name: The name of the component that triggered the breakpoint. + :param component_visits: The visit count of the component that triggered the breakpoint. + :param callback_fun: A function to call with the saved state. + :param debug_path: The path to save the state to. + :param original_input_data: The original input data. + :param ordered_component_names: The ordered component names. + :raises: + Exception: If the debug_path is not a string or a Path object, or if saving the JSON state fails. + + :returns: + The saved state dictionary + """ + dt = datetime.now() + + state = { + "input_data": _serialize_component_input(original_input_data), # original input data + "timestamp": dt.isoformat(), + "pipeline_breakpoint": {"component": component_name, "visits": component_visits[component_name]}, + "pipeline_state": { + "inputs": _serialize_component_input(inputs), # current pipeline state inputs + "component_visits": component_visits, + "ordered_component_names": ordered_component_names, + }, + } + + if not debug_path: + return state + + debug_path = Path(debug_path) if isinstance(debug_path, str) else debug_path + if not isinstance(debug_path, Path): + raise ValueError("Debug path must be a string or a Path object.") + debug_path.mkdir(exist_ok=True) + file_name = Path(f"{component_name}_{dt.strftime('%Y_%m_%d_%H_%M_%S')}.json") + + try: + with open(debug_path / file_name, "w") as f_out: + json.dump(state, f_out, indent=2) + logger.info(f"Pipeline state saved at: {file_name}") + + # pass the state to some user-defined callback function + if callback_fun is not None: + callback_fun(state) + + return state + + except Exception as e: + logger.error(f"Failed to save pipeline state: {str(e)}") + raise + + +def _deserialize_component_input(value: Any) -> Any: # noqa: PLR0911 + """ + Tries to deserialize any type of input that can be passed to as input to a pipeline component. + + For primitive values, it returns the value as is, but for complex types, it tries to deserialize them. + """ + + # None or primitive types are returned as is + if not value or isinstance(value, (str, int, float, bool)): + return value + + # list of primitive types are returned as is + if isinstance(value, list) and all(isinstance(i, (str, int, float, bool)) for i in value): + return value + + if isinstance(value, list): + # list of lists are called recursively + if all(isinstance(i, list) for i in value): + return [_deserialize_component_input(i) for i in value] + # list of dicts are called recursively + if all(isinstance(i, dict) for i in value): + return [_deserialize_component_input(i) for i in value] + + # Define the mapping of types to their deserialization functions + _type_deserializers = { + "Answer": Answer.from_dict, + "ChatMessage": ChatMessage.from_dict, + "Document": Document.from_dict, + "ExtractedAnswer": ExtractedAnswer.from_dict, + "GeneratedAnswer": GeneratedAnswer.from_dict, + "SparseEmbedding": SparseEmbedding.from_dict, + } + + # check if the dictionary has a "_type" key and if it's a known type + if isinstance(value, dict): + if "_type" in value: + type_name = value.pop("_type") + if type_name in _type_deserializers: + return _type_deserializers[type_name](value) + + # If not a known type, recursively deserialize each item in the dictionary + return {k: _deserialize_component_input(v) for k, v in value.items()} + + return value + + +def _transform_json_structure(data: Union[Dict[str, Any], List[Any], Any]) -> Any: + """ + Transforms a JSON structure by removing the 'sender' key and moving the 'value' to the top level. + + For example: + "key": [{"sender": null, "value": "some value"}] -> "key": "some value" + + :param data: The JSON structure to transform. + :returns: The transformed structure. + """ + if isinstance(data, dict): + # If this dict has both 'sender' and 'value', return just the value + if "value" in data and "sender" in data: + return data["value"] + # Otherwise, recursively process each key-value pair + return {k: _transform_json_structure(v) for k, v in data.items()} + + if isinstance(data, list): + # First, transform each item in the list. + transformed = [_transform_json_structure(item) for item in data] + # If the original list has exactly one element and that element was a dict + # with 'sender' and 'value', then unwrap the list. + if len(data) == 1 and isinstance(data[0], dict) and "value" in data[0] and "sender" in data[0]: + return transformed[0] + return transformed + + # For other data types, just return the value as is. + return data + + +def _serialize_component_input(value: Any) -> Any: + """ + Serializes, so it can be saved to a file, any type of input to a pipeline component. + + :param value: The value to serialize. + :returns: The serialized value that can be saved to a file. + """ + value = _transform_json_structure(value) + if hasattr(value, "to_dict") and callable(getattr(value, "to_dict")): + serialized_value = value.to_dict() + serialized_value["_type"] = value.__class__.__name__ + return serialized_value + + # this is a hack to serialize inputs that don't have a to_dict + elif hasattr(value, "__dict__"): + return { + "_type": value.__class__.__name__, + "attributes": value.__dict__, + } + + # recursively serialize all inputs in a dict + elif isinstance(value, dict): + return {k: _serialize_component_input(v) for k, v in value.items()} + + # recursively serialize all inputs in lists or tuples + elif isinstance(value, list): + return [_serialize_component_input(item) for item in value] + + return value diff --git a/haystack_experimental/core/pipeline/pipeline.py b/haystack_experimental/core/pipeline/pipeline.py index 143f756d..c8fe6abc 100644 --- a/haystack_experimental/core/pipeline/pipeline.py +++ b/haystack_experimental/core/pipeline/pipeline.py @@ -5,16 +5,13 @@ # pylint: disable=too-many-return-statements, too-many-positional-arguments -import json from copy import deepcopy -from datetime import datetime from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Optional, Set, Tuple, Union -from haystack import Answer, Document, ExtractedAnswer, logging, tracing +from haystack import logging, tracing from haystack.core.pipeline.base import ComponentPriority from haystack.core.pipeline.pipeline import Pipeline as HaystackPipeline -from haystack.dataclasses import ChatMessage, SparseEmbedding from haystack.telemetry import pipeline_running from haystack_experimental.core.errors import ( @@ -22,7 +19,8 @@ PipelineInvalidResumeStateError, ) from haystack_experimental.core.pipeline.base import PipelineBase -from haystack_experimental.dataclasses import GeneratedAnswer + +from .breakpoint import _deserialize_component_input, _save_state, _validate_breakpoint, _validate_pipeline_state logger = logging.getLogger(__name__) @@ -154,7 +152,7 @@ def run( # noqa: PLR0915, PLR0912 ) # make sure pipeline_breakpoint is valid and have a default visit count - validated_breakpoint = self._validate_breakpoint(pipeline_breakpoint) if pipeline_breakpoint else None + validated_breakpoint = _validate_breakpoint(pipeline_breakpoint, self.graph) if pipeline_breakpoint else None # TODO: Remove this warmup once we can check reliably whether a component has been warmed up or not # As of now it's here to make sure we don't have failing tests that assume warm_up() is called in run() @@ -254,7 +252,7 @@ def run( # noqa: PLR0915, PLR0912 state_inputs_serialised = deepcopy(inputs) state_inputs_serialised[component_name] = deepcopy(component_inputs) - Pipeline._save_state( + _save_state( inputs=state_inputs_serialised, component_name=str(component_name), component_visits=component_visits, @@ -310,7 +308,7 @@ def inject_resume_state_into_graph(self, resume_state): if not resume_state: raise PipelineInvalidResumeStateError("Cannot inject resume state: resume_state is None") - self._validate_pipeline_state(resume_state) + _validate_pipeline_state(resume_state, graph=self.graph) data = self._prepare_component_input_data(resume_state["pipeline_state"]["inputs"]) component_visits = resume_state["pipeline_state"]["component_visits"] ordered_component_names = resume_state["pipeline_state"]["ordered_component_names"] @@ -320,301 +318,3 @@ def inject_resume_state_into_graph(self, resume_state): ) logger.info(msg) return component_visits, data, resume_state, ordered_component_names - - def _validate_breakpoint(self, pipeline_breakpoint: Tuple[str, Optional[int]]) -> Tuple[str, int]: - """ - Validates the pipeline_breakpoint passed to the pipeline. - - Make sure they are all valid components registered in the pipeline, - If the visit is not given, it is assumed to be 0, it will break on the first visit. - - :param pipeline_breakpoint: Tuple of component name and visit count at which the pipeline should stop. - :returns: - Tuple of valid pipeline_breakpoint. - """ - - if pipeline_breakpoint and pipeline_breakpoint[0] not in self.graph.nodes: - raise ValueError(f"pipeline_breakpoint {pipeline_breakpoint} is not a registered component in the pipeline") - valid_breakpoint: Tuple[str, int] = ( - (pipeline_breakpoint[0], 0 if pipeline_breakpoint[1] is None else pipeline_breakpoint[1]) - if pipeline_breakpoint - else None - ) - return valid_breakpoint - - def _validate_pipeline_state(self, resume_state: Dict[str, Any]) -> None: - """ - Validates that the resume_state contains valid configuration for the current pipeline. - - Raises a PipelineRuntimeError if any component is missing or if the state structure is invalid. - - :param resume_state: The saved state to validate. - """ - - pipeline_state = resume_state["pipeline_state"] - valid_components = set(self.graph.nodes.keys()) - - # Check if the ordered_component_names are valid components in the pipeline - missing_ordered = set(pipeline_state["ordered_component_names"]) - valid_components - if missing_ordered: - raise PipelineInvalidResumeStateError( - f"Invalid resume state: components {missing_ordered} in 'ordered_component_names' " - f"are not part of the current pipeline." - ) - - # Check if the input_data is valid components in the pipeline - missing_input = set(resume_state["input_data"].keys()) - valid_components - if missing_input: - raise PipelineInvalidResumeStateError( - f"Invalid resume state: components {missing_input} in 'input_data' " - f"are not part of the current pipeline." - ) - - # Validate 'component_visits' - missing_visits = set(pipeline_state["component_visits"].keys()) - valid_components - if missing_visits: - raise PipelineInvalidResumeStateError( - f"Invalid resume state: components {missing_visits} in 'component_visits' " - f"are not part of the current pipeline." - ) - - logger.info( - f"Resuming pipeline from component: {resume_state['pipeline_breakpoint']['component']} " - f"(visit {resume_state['pipeline_breakpoint']['visits']})" - ) - - @staticmethod - def _validate_resume_state(state: Dict[str, Any]) -> None: - """ - Validates the loaded pipeline state. - - Ensures that the state contains required keys: "input_data", "pipeline_breakpoint", and "pipeline_state". - - Raises: - ValueError: If required keys are missing or the component sets are inconsistent. - """ - - # top-level state has all required keys - required_top_keys = {"input_data", "pipeline_breakpoint", "pipeline_state"} - missing_top = required_top_keys - state.keys() - if missing_top: - raise ValueError(f"Invalid state file: missing required keys {missing_top}") - - # pipeline_state has the necessary keys - pipeline_state = state["pipeline_state"] - required_pipeline_keys = {"inputs", "component_visits", "ordered_component_names"} - missing_pipeline = required_pipeline_keys - pipeline_state.keys() - if missing_pipeline: - raise ValueError(f"Invalid pipeline_state: missing required keys {missing_pipeline}") - - # component_visits and ordered_component_names must be consistent - components_in_state = set(pipeline_state["component_visits"].keys()) - components_in_order = set(pipeline_state["ordered_component_names"]) - - if components_in_state != components_in_order: - raise ValueError( - f"Inconsistent state: components in pipeline_state['component_visits'] {components_in_state} " - f"do not match components in ordered_component_names {components_in_order}" - ) - - logger.info("Passed resume state validated successfully.") - - @staticmethod - def load_state(file_path: Union[str, Path]) -> Dict[str, Any]: - """ - Load a saved pipeline state. - - :param file_path: Path to the state file - :returns: - Dict containing the loaded state - """ - - file_path = Path(file_path) - - try: - with open(file_path, "r", encoding="utf-8") as f: - state = json.load(f) - except FileNotFoundError: - raise FileNotFoundError(f"File not found: {file_path}") - except json.JSONDecodeError as e: - raise json.JSONDecodeError(f"Invalid JSON file {file_path}: {str(e)}", e.doc, e.pos) - except IOError as e: - raise IOError(f"Error reading {file_path}: {str(e)}") - - try: - Pipeline._validate_resume_state(state=state) - except ValueError as e: - raise ValueError(f"Invalid pipeline state from {file_path}: {str(e)}") - - logger.info(f"Successfully loaded pipeline state from: {file_path}") - return state - - @staticmethod - def _save_state( - inputs: Dict[str, Any], - component_name: str, - component_visits: Dict[str, int], - callback_fun: Optional[Callable[..., Any]] = None, - debug_path: Optional[Union[str, Path]] = None, - original_input_data: Optional[Dict[str, Any]] = None, - ordered_component_names: Optional[List[str]] = None, - ) -> Dict[str, Any]: - """ - Save the pipeline state to a file. - - :param inputs: The current pipeline state inputs. - :param component_name: The name of the component that triggered the breakpoint. - :param component_visits: The visit count of the component that triggered the breakpoint. - :param callback_fun: A function to call with the saved state. - :param debug_path: The path to save the state to. - :param original_input_data: The original input data. - :param ordered_component_names: The ordered component names. - :raises: - Exception: If the debug_path is not a string or a Path object, or if saving the JSON state fails. - - :returns: - The saved state dictionary - """ - dt = datetime.now() - - state = { - "input_data": _serialize_component_input(original_input_data), # original input data - "timestamp": dt.isoformat(), - "pipeline_breakpoint": {"component": component_name, "visits": component_visits[component_name]}, - "pipeline_state": { - "inputs": _serialize_component_input(inputs), # current pipeline state inputs - "component_visits": component_visits, - "ordered_component_names": ordered_component_names, - }, - } - - if not debug_path: - return state - - debug_path = Path(debug_path) if isinstance(debug_path, str) else debug_path - if not isinstance(debug_path, Path): - raise ValueError("Debug path must be a string or a Path object.") - debug_path.mkdir(exist_ok=True) - file_name = Path(f"{component_name}_{dt.strftime('%Y_%m_%d_%H_%M_%S')}.json") - - try: - with open(debug_path / file_name, "w") as f_out: - json.dump(state, f_out, indent=2) - logger.info(f"Pipeline state saved at: {file_name}") - - # pass the state to some user-defined callback function - if callback_fun is not None: - callback_fun(state) - - return state - - except Exception as e: - logger.error(f"Failed to save pipeline state: {str(e)}") - raise - - -def _deserialize_component_input(value: Any) -> Any: # noqa: PLR0911 - """ - Tries to deserialize any type of input that can be passed to as input to a pipeline component. - - For primitive values, it returns the value as is, but for complex types, it tries to deserialize them. - """ - - # None or primitive types are returned as is - if not value or isinstance(value, (str, int, float, bool)): - return value - - # list of primitive types are returned as is - if isinstance(value, list) and all(isinstance(i, (str, int, float, bool)) for i in value): - return value - - if isinstance(value, list): - # list of lists are called recursively - if all(isinstance(i, list) for i in value): - return [_deserialize_component_input(i) for i in value] - # list of dicts are called recursively - if all(isinstance(i, dict) for i in value): - return [_deserialize_component_input(i) for i in value] - - # Define the mapping of types to their deserialization functions - _type_deserializers = { - "Answer": Answer.from_dict, - "ChatMessage": ChatMessage.from_dict, - "Document": Document.from_dict, - "ExtractedAnswer": ExtractedAnswer.from_dict, - "GeneratedAnswer": GeneratedAnswer.from_dict, - "SparseEmbedding": SparseEmbedding.from_dict, - } - - # check if the dictionary has a "_type" key and if it's a known type - if isinstance(value, dict): - if "_type" in value: - type_name = value.pop("_type") - if type_name in _type_deserializers: - return _type_deserializers[type_name](value) - - # If not a known type, recursively deserialize each item in the dictionary - return {k: _deserialize_component_input(v) for k, v in value.items()} - - return value - - -def _transform_json_structure(data: Union[Dict[str, Any], List[Any], Any]) -> Any: - """ - Transforms a JSON structure by removing the 'sender' key and moving the 'value' to the top level. - - For example: - "key": [{"sender": null, "value": "some value"}] -> "key": "some value" - - :param data: The JSON structure to transform. - :returns: The transformed structure. - """ - if isinstance(data, dict): - # If this dict has both 'sender' and 'value', return just the value - if "value" in data and "sender" in data: - return data["value"] - # Otherwise, recursively process each key-value pair - return {k: _transform_json_structure(v) for k, v in data.items()} - - if isinstance(data, list): - # First, transform each item in the list. - transformed = [_transform_json_structure(item) for item in data] - # If the original list has exactly one element and that element was a dict - # with 'sender' and 'value', then unwrap the list. - if len(data) == 1 and isinstance(data[0], dict) and "value" in data[0] and "sender" in data[0]: - return transformed[0] - return transformed - - # For other data types, just return the value as is. - return data - - -def _serialize_component_input(value: Any) -> Any: - """ - Serializes, so it can be saved to a file, any type of input to a pipeline component. - - :param value: The value to serialize. - :returns: The serialized value that can be saved to a file. - """ - value = _transform_json_structure(value) - if hasattr(value, "to_dict") and callable(getattr(value, "to_dict")): - serialized_value = value.to_dict() - serialized_value["_type"] = value.__class__.__name__ - return serialized_value - - # this is a hack to serialize inputs that don't have a to_dict - elif hasattr(value, "__dict__"): - return { - "_type": value.__class__.__name__, - "attributes": value.__dict__, - } - - # recursively serialize all inputs in a dict - elif isinstance(value, dict): - return {k: _serialize_component_input(v) for k, v in value.items()} - - # recursively serialize all inputs in lists or tuples - elif isinstance(value, list): - return [_serialize_component_input(item) for item in value] - - return value diff --git a/test/conftest.py b/test/conftest.py index ee8e7b0d..759101f9 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -13,6 +13,7 @@ from haystack.testing.test_utils import set_all_seeds from haystack import tracing, component +from haystack_experimental.core.pipeline.breakpoint import load_state set_all_seeds(0) @@ -75,7 +76,7 @@ def load_and_resume_pipeline_state(pipeline, output_directory: Path, component: f_name = Path(full_path).name if str(f_name).startswith(component): file_found = True - resume_state = pipeline.load_state(full_path) + resume_state = load_state(full_path) return pipeline.run(data=data, resume_state=resume_state) if not file_found: diff --git a/test/core/pipeline/test_breakpoint.py b/test/core/pipeline/test_breakpoint.py new file mode 100644 index 00000000..8da5e37b --- /dev/null +++ b/test/core/pipeline/test_breakpoint.py @@ -0,0 +1,291 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import json +import pytest + +from haystack.components.joiners import BranchJoiner +from haystack_experimental.core.pipeline.pipeline import Pipeline +from haystack_experimental.core.pipeline.breakpoint import _transform_json_structure, _serialize_component_input, _deserialize_component_input, load_state, _validate_breakpoint, _validate_resume_state + + +class TestBreakpoint: + """ + This class contains only unit tests for the breakpoint module. + """ + + def test_validate_breakpoint(self): + # simple pipeline + joiner_1 = BranchJoiner(type_=str) + joiner_2 = BranchJoiner(type_=str) + pipeline = Pipeline() + pipeline.add_component("comp1", joiner_1) + pipeline.add_component("comp2", joiner_2) + pipeline.connect("comp1", "comp2") + + # valid breakpoints + breakpoints = ("comp1", 0) + validated = _validate_breakpoint(breakpoints, pipeline.graph) + assert validated == ("comp1", 0) + + # should default to 0 + breakpoints = ("comp1", None) + validated = _validate_breakpoint(breakpoints, pipeline.graph) + assert validated == ("comp1", 0) + + # should remain as it is + breakpoints = ("comp1", -1) + validated = _validate_breakpoint(breakpoints, pipeline.graph) + assert validated == ("comp1", -1) + + # contains invalid components + breakpoints = ("comp3", 0) + with pytest.raises(ValueError, match="pipeline_breakpoint .* is not a registered component"): + _validate_breakpoint(breakpoints, pipeline.graph) + + # no breakpoints are defined + breakpoint = None + validated = _validate_breakpoint(breakpoint, pipeline.graph) + assert validated is None + + +def test_transform_json_structure_unwraps_sender_value(): + data = { + "key1": [{"sender": None, "value": "some value"}], + "key2": [{"sender": "comp1", "value": 42}], + "key3": "direct value" + } + + result = _transform_json_structure(data) + + assert result == { + "key1": "some value", + "key2": 42, + "key3": "direct value" + } + +def test_transform_json_structure_handles_nested_structures(): + data = { + "key1": [{"sender": None, "value": "value1"}], + "key2": { + "nested": [{"sender": "comp1", "value": "value2"}], + "direct": "value3" + }, + "key3": [ + [{"sender": None, "value": "value4"}], + [{"sender": "comp2", "value": "value5"}] + ] + } + + result = _transform_json_structure(data) + + assert result == { + "key1": "value1", + "key2": { + "nested": "value2", + "direct": "value3" + }, + "key3": [ + "value4", + "value5" + ] + } + +def test_serialize_component_input_handles_objects_with_to_dict(): + class TestObject: + def __init__(self, value): + self.value = value + + def to_dict(self): + return {"value": self.value} + + obj = TestObject("test") + result = _serialize_component_input(obj) + assert result == { + "_type": "TestObject", + "value": "test" + } + +def test_serialize_component_input_handles_objects_without_to_dict(): + class TestObject: + def __init__(self, value): + self.value = value + + obj = TestObject("test") + result = _serialize_component_input(obj) + assert result == { + "_type": "TestObject", + "attributes": {"value": "test"} + } + +def test_serialize_component_input_handles_nested_structures(): + class TestObject: + def __init__(self, value): + self.value = value + + def to_dict(self): + return {"value": self.value} + + obj = TestObject("test") + data = { + "key1": obj, + "key2": [obj, "string"], + "key3": {"nested": obj} + } + result = _serialize_component_input(data) + + assert result["key1"]["_type"] == "TestObject" + assert result["key2"][0]["_type"] == "TestObject" + assert result["key2"][1] == "string" + assert result["key3"]["nested"]["_type"] == "TestObject" + +def test_deserialize_component_input_handles_primitive_types(): + data = { + "string": "test", + "int": 42, + "float": 3.14, + "bool": True, + "none": None + } + result = _deserialize_component_input(data) + assert result == data + +def test_deserialize_component_input_handles_lists(): + data = { + "primitive_list": [1, 2, 3], + "mixed_list": [1, "string", True] + } + result = _deserialize_component_input(data) + assert result == data + +def test_deserialize_component_input_handles_dicts(): + data = { + "key1": "value1", + "key2": {"nested": "value2"} + } + result = _deserialize_component_input(data) + assert result == data + +def test_deserialize_component_input_handles_nested_lists(): + """Test that _deserialize_component_input handles nested lists""" + data = { + "nested_list": [[1, 2], [3, 4]], + "mixed_nested": [[1, "string"], [True, 3.14]] + } + + result = _deserialize_component_input(data) + + assert result == data + +def test_deserialize_component_input_handles_nested_dicts(): + """Test that _deserialize_component_input handles nested dictionaries""" + data = { + "key1": { + "nested1": "value1", + "nested2": { + "deep": "value2" + } + } + } + + result = _deserialize_component_input(data) + + assert result == data + +def test_deserialize_component_input_handles_empty_structures(): + """Test that _deserialize_component_input handles empty structures""" + data = { + "empty_list": [], + "empty_dict": {}, + "nested_empty": {"empty": []} + } + + result = _deserialize_component_input(data) + + assert result == data + +def test_validate_resume_state_validates_required_keys(): + state = { + "input_data": {}, + "pipeline_breakpoint": {"component": "comp1", "visits": 0} + # Missing pipeline_state + } + + with pytest.raises(ValueError, match="Invalid state file: missing required keys"): + _validate_resume_state(state) + + state = { + "input_data": {}, + "pipeline_breakpoint": {"component": "comp1", "visits": 0}, + "pipeline_state": { + "inputs": {}, + "component_visits": {} + # Missing ordered_component_names + } + } + + with pytest.raises(ValueError, match="Invalid pipeline_state: missing required keys"): + _validate_resume_state(state) + +def test_validate_resume_state_validates_component_consistency(): + state = { + "input_data": {}, + "pipeline_breakpoint": {"component": "comp1", "visits": 0}, + "pipeline_state": { + "inputs": {}, + "component_visits": {"comp1": 0, "comp2": 0}, + "ordered_component_names": ["comp1", "comp3"] # inconsistent with component_visits + } + } + + with pytest.raises(ValueError, match="Inconsistent state: components in pipeline_state"): + _validate_resume_state(state) + +def test_validate_resume_state_validates_valid_state(): + state = { + "input_data": {}, + "pipeline_breakpoint": {"component": "comp1", "visits": 0}, + "pipeline_state": { + "inputs": {}, + "component_visits": {"comp1": 0, "comp2": 0}, + "ordered_component_names": ["comp1", "comp2"] + } + } + + _validate_resume_state(state) # should not raise any exception + +def test_load_state_loads_valid_state(tmp_path): + state = { + "input_data": {}, + "pipeline_breakpoint": {"component": "comp1", "visits": 0}, + "pipeline_state": { + "inputs": {}, + "component_visits": {"comp1": 0, "comp2": 0}, + "ordered_component_names": ["comp1", "comp2"] + } + } + state_file = tmp_path / "state.json" + with open(state_file, "w") as f: + json.dump(state, f) + + loaded_state = load_state(state_file) + assert loaded_state == state + +def test_load_state_handles_invalid_state(tmp_path): + state = { + "input_data": {}, + "pipeline_breakpoint": {"component": "comp1", "visits": 0}, + "pipeline_state": { + "inputs": {}, + "component_visits": {"comp1": 0, "comp2": 0}, + "ordered_component_names": ["comp1", "comp3"] # inconsistent with component_visits + } + } + + state_file = tmp_path / "invalid_state.json" + with open(state_file, "w") as f: + json.dump(state, f) + + with pytest.raises(ValueError, match="Invalid pipeline state from"): + load_state(state_file) diff --git a/test/core/pipeline/test_pipeline.py b/test/core/pipeline/test_pipeline.py index 7280e88a..cf7559ea 100644 --- a/test/core/pipeline/test_pipeline.py +++ b/test/core/pipeline/test_pipeline.py @@ -11,9 +11,6 @@ from haystack.core.errors import PipelineRuntimeError from haystack_experimental.core.pipeline.pipeline import ( Pipeline, - _transform_json_structure, - _serialize_component_input, - _deserialize_component_input ) @@ -113,278 +110,3 @@ def test_run(self): pp.connect("joiner_1", "joiner_2") _ = pp.run({"value": "test_value"}) - - def test_validate_breakpoint(self): - # simple pipeline - joiner_1 = BranchJoiner(type_=str) - joiner_2 = BranchJoiner(type_=str) - pipeline = Pipeline() - pipeline.add_component("comp1", joiner_1) - pipeline.add_component("comp2", joiner_2) - pipeline.connect("comp1", "comp2") - - # valid breakpoints - breakpoints = ("comp1", 0) - validated = pipeline._validate_breakpoint(breakpoints) - assert validated == ("comp1", 0) - - # should default to 0 - breakpoints = ("comp1", None) - validated = pipeline._validate_breakpoint(breakpoints) - assert validated == ("comp1", 0) - - # should remain as it is - breakpoints = ("comp1", -1) - validated = pipeline._validate_breakpoint(breakpoints) - assert validated == ("comp1", -1) - - # contains invalid components - breakpoints = ("comp3", 0) - with pytest.raises(ValueError, match="pipeline_breakpoint .* is not a registered component"): - pipeline._validate_breakpoint(breakpoints) - - # no breakpoints are defined - breakpoint = None - validated = pipeline._validate_breakpoint(breakpoint) - assert validated is None - - -def test_transform_json_structure_unwraps_sender_value(): - data = { - "key1": [{"sender": None, "value": "some value"}], - "key2": [{"sender": "comp1", "value": 42}], - "key3": "direct value" - } - - result = _transform_json_structure(data) - - assert result == { - "key1": "some value", - "key2": 42, - "key3": "direct value" - } - -def test_transform_json_structure_handles_nested_structures(): - data = { - "key1": [{"sender": None, "value": "value1"}], - "key2": { - "nested": [{"sender": "comp1", "value": "value2"}], - "direct": "value3" - }, - "key3": [ - [{"sender": None, "value": "value4"}], - [{"sender": "comp2", "value": "value5"}] - ] - } - - result = _transform_json_structure(data) - - assert result == { - "key1": "value1", - "key2": { - "nested": "value2", - "direct": "value3" - }, - "key3": [ - "value4", - "value5" - ] - } - -def test_serialize_component_input_handles_objects_with_to_dict(): - class TestObject: - def __init__(self, value): - self.value = value - - def to_dict(self): - return {"value": self.value} - - obj = TestObject("test") - result = _serialize_component_input(obj) - assert result == { - "_type": "TestObject", - "value": "test" - } - -def test_serialize_component_input_handles_objects_without_to_dict(): - class TestObject: - def __init__(self, value): - self.value = value - - obj = TestObject("test") - result = _serialize_component_input(obj) - assert result == { - "_type": "TestObject", - "attributes": {"value": "test"} - } - -def test_serialize_component_input_handles_nested_structures(): - class TestObject: - def __init__(self, value): - self.value = value - - def to_dict(self): - return {"value": self.value} - - obj = TestObject("test") - data = { - "key1": obj, - "key2": [obj, "string"], - "key3": {"nested": obj} - } - result = _serialize_component_input(data) - - assert result["key1"]["_type"] == "TestObject" - assert result["key2"][0]["_type"] == "TestObject" - assert result["key2"][1] == "string" - assert result["key3"]["nested"]["_type"] == "TestObject" - -def test_deserialize_component_input_handles_primitive_types(): - data = { - "string": "test", - "int": 42, - "float": 3.14, - "bool": True, - "none": None - } - result = _deserialize_component_input(data) - assert result == data - -def test_deserialize_component_input_handles_lists(): - data = { - "primitive_list": [1, 2, 3], - "mixed_list": [1, "string", True] - } - result = _deserialize_component_input(data) - assert result == data - -def test_deserialize_component_input_handles_dicts(): - data = { - "key1": "value1", - "key2": {"nested": "value2"} - } - result = _deserialize_component_input(data) - assert result == data - -def test_deserialize_component_input_handles_nested_lists(): - """Test that _deserialize_component_input handles nested lists""" - data = { - "nested_list": [[1, 2], [3, 4]], - "mixed_nested": [[1, "string"], [True, 3.14]] - } - - result = _deserialize_component_input(data) - - assert result == data - -def test_deserialize_component_input_handles_nested_dicts(): - """Test that _deserialize_component_input handles nested dictionaries""" - data = { - "key1": { - "nested1": "value1", - "nested2": { - "deep": "value2" - } - } - } - - result = _deserialize_component_input(data) - - assert result == data - -def test_deserialize_component_input_handles_empty_structures(): - """Test that _deserialize_component_input handles empty structures""" - data = { - "empty_list": [], - "empty_dict": {}, - "nested_empty": {"empty": []} - } - - result = _deserialize_component_input(data) - - assert result == data - -def test_validate_resume_state_validates_required_keys(): - state = { - "input_data": {}, - "pipeline_breakpoint": {"component": "comp1", "visits": 0} - # Missing pipeline_state - } - - with pytest.raises(ValueError, match="Invalid state file: missing required keys"): - Pipeline._validate_resume_state(state) - - state = { - "input_data": {}, - "pipeline_breakpoint": {"component": "comp1", "visits": 0}, - "pipeline_state": { - "inputs": {}, - "component_visits": {} - # Missing ordered_component_names - } - } - - with pytest.raises(ValueError, match="Invalid pipeline_state: missing required keys"): - Pipeline._validate_resume_state(state) - -def test_validate_resume_state_validates_component_consistency(): - state = { - "input_data": {}, - "pipeline_breakpoint": {"component": "comp1", "visits": 0}, - "pipeline_state": { - "inputs": {}, - "component_visits": {"comp1": 0, "comp2": 0}, - "ordered_component_names": ["comp1", "comp3"] # inconsistent with component_visits - } - } - - with pytest.raises(ValueError, match="Inconsistent state: components in pipeline_state"): - Pipeline._validate_resume_state(state) - -def test_validate_resume_state_validates_valid_state(): - state = { - "input_data": {}, - "pipeline_breakpoint": {"component": "comp1", "visits": 0}, - "pipeline_state": { - "inputs": {}, - "component_visits": {"comp1": 0, "comp2": 0}, - "ordered_component_names": ["comp1", "comp2"] - } - } - - Pipeline._validate_resume_state(state) # should not raise any exception - -def test_load_state_loads_valid_state(tmp_path): - state = { - "input_data": {}, - "pipeline_breakpoint": {"component": "comp1", "visits": 0}, - "pipeline_state": { - "inputs": {}, - "component_visits": {"comp1": 0, "comp2": 0}, - "ordered_component_names": ["comp1", "comp2"] - } - } - state_file = tmp_path / "state.json" - with open(state_file, "w") as f: - json.dump(state, f) - - loaded_state = Pipeline.load_state(state_file) - assert loaded_state == state - -def test_load_state_handles_invalid_state(tmp_path): - state = { - "input_data": {}, - "pipeline_breakpoint": {"component": "comp1", "visits": 0}, - "pipeline_state": { - "inputs": {}, - "component_visits": {"comp1": 0, "comp2": 0}, - "ordered_component_names": ["comp1", "comp3"] # inconsistent with component_visits - } - } - - state_file = tmp_path / "invalid_state.json" - with open(state_file, "w") as f: - json.dump(state, f) - - with pytest.raises(ValueError, match="Invalid pipeline state from"): - Pipeline.load_state(state_file) diff --git a/test/core/pipeline/test_pipeline_breakpoints_branch_joiner.py b/test/core/pipeline/test_pipeline_breakpoints_branch_joiner.py index 5d23c2b6..ad2d0ded 100644 --- a/test/core/pipeline/test_pipeline_breakpoints_branch_joiner.py +++ b/test/core/pipeline/test_pipeline_breakpoints_branch_joiner.py @@ -13,6 +13,7 @@ from haystack.utils.auth import Secret from haystack_experimental.core.errors import PipelineBreakpointException from haystack_experimental.core.pipeline.pipeline import Pipeline +from haystack_experimental.core.pipeline.breakpoint import load_state from unittest.mock import patch class TestPipelineBreakpoints: @@ -116,7 +117,7 @@ def test_pipeline_breakpoints_branch_joiner(self, branch_joiner_pipeline, output f_name = Path(full_path).name if str(f_name).startswith(component): file_found = True - resume_state = Pipeline.load_state(full_path) + resume_state = load_state(full_path) result = branch_joiner_pipeline.run(data, resume_state=resume_state) assert result['validator'] break diff --git a/test/core/pipeline/test_pipeline_breakpoints_loops.py b/test/core/pipeline/test_pipeline_breakpoints_loops.py index cb840f9d..12b1197d 100644 --- a/test/core/pipeline/test_pipeline_breakpoints_loops.py +++ b/test/core/pipeline/test_pipeline_breakpoints_loops.py @@ -16,6 +16,7 @@ from haystack.utils.auth import Secret from haystack_experimental.core.errors import PipelineBreakpointException from haystack_experimental.core.pipeline.pipeline import Pipeline +from haystack_experimental.core.pipeline.breakpoint import load_state from unittest.mock import patch # Define the component input parameters @@ -210,7 +211,7 @@ def test_pipeline_breakpoints_validation_loop(self, validation_loop_pipeline, ou f_name = Path(full_path).name if str(f_name).startswith(component): file_found = True - resume_state = Pipeline.load_state(full_path) + resume_state = load_state(full_path) result = validation_loop_pipeline.run(data={}, resume_state=resume_state) # Verify the result contains valid output if "output_validator" in result and "valid_replies" in result["output_validator"]: From 3f6c6bb6086cdc8bb770829b1d0114a7a44ebc64 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Thu, 12 Jun 2025 12:53:20 +0200 Subject: [PATCH 21/30] PR comments --- haystack_experimental/core/pipeline/pipeline.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/haystack_experimental/core/pipeline/pipeline.py b/haystack_experimental/core/pipeline/pipeline.py index c8fe6abc..c3e2fcf8 100644 --- a/haystack_experimental/core/pipeline/pipeline.py +++ b/haystack_experimental/core/pipeline/pipeline.py @@ -10,14 +10,12 @@ from typing import Any, Dict, Optional, Set, Tuple, Union from haystack import logging, tracing +from haystack.core.errors import PipelineRuntimeError from haystack.core.pipeline.base import ComponentPriority from haystack.core.pipeline.pipeline import Pipeline as HaystackPipeline from haystack.telemetry import pipeline_running -from haystack_experimental.core.errors import ( - PipelineBreakpointException, - PipelineInvalidResumeStateError, -) +from haystack_experimental.core.errors import PipelineBreakpointException, PipelineInvalidResumeStateError from haystack_experimental.core.pipeline.base import PipelineBase from .breakpoint import _deserialize_component_input, _save_state, _validate_breakpoint, _validate_pipeline_state @@ -147,9 +145,11 @@ def run( # noqa: PLR0915, PLR0912 pipeline_running(self) if pipeline_breakpoint and resume_state: - logger.warning( - "pipeline_breakpoint will be ignored because it cannot be provided when resuming a pipeline.", + msg = ( + "pipeline_breakpoint and resume_state cannot be provided at the same time. " + "The pipeline run will be aborted." ) + raise PipelineRuntimeError(message=msg) # make sure pipeline_breakpoint is valid and have a default visit count validated_breakpoint = _validate_breakpoint(pipeline_breakpoint, self.graph) if pipeline_breakpoint else None From 8ee3eee1251e85553cf1326758723670935d63a4 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Thu, 12 Jun 2025 12:59:04 +0200 Subject: [PATCH 22/30] Linting --- haystack_experimental/core/pipeline/pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haystack_experimental/core/pipeline/pipeline.py b/haystack_experimental/core/pipeline/pipeline.py index c3e2fcf8..35b34223 100644 --- a/haystack_experimental/core/pipeline/pipeline.py +++ b/haystack_experimental/core/pipeline/pipeline.py @@ -149,7 +149,7 @@ def run( # noqa: PLR0915, PLR0912 "pipeline_breakpoint and resume_state cannot be provided at the same time. " "The pipeline run will be aborted." ) - raise PipelineRuntimeError(message=msg) + raise PipelineInvalidResumeStateError(message=msg) # make sure pipeline_breakpoint is valid and have a default visit count validated_breakpoint = _validate_breakpoint(pipeline_breakpoint, self.graph) if pipeline_breakpoint else None From 65c5d716c44e7aaea9a810c504d57fe69a289629 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Thu, 12 Jun 2025 13:06:29 +0200 Subject: [PATCH 23/30] Linting --- haystack_experimental/core/pipeline/pipeline.py | 1 - 1 file changed, 1 deletion(-) diff --git a/haystack_experimental/core/pipeline/pipeline.py b/haystack_experimental/core/pipeline/pipeline.py index 35b34223..cbeea62d 100644 --- a/haystack_experimental/core/pipeline/pipeline.py +++ b/haystack_experimental/core/pipeline/pipeline.py @@ -10,7 +10,6 @@ from typing import Any, Dict, Optional, Set, Tuple, Union from haystack import logging, tracing -from haystack.core.errors import PipelineRuntimeError from haystack.core.pipeline.base import ComponentPriority from haystack.core.pipeline.pipeline import Pipeline as HaystackPipeline from haystack.telemetry import pipeline_running From 1c2b4f455ee472be2a132bcdf1c6fee5536ddfc0 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Thu, 12 Jun 2025 15:35:34 +0200 Subject: [PATCH 24/30] Update haystack_experimental/core/pipeline/breakpoint.py Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> --- haystack_experimental/core/pipeline/breakpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haystack_experimental/core/pipeline/breakpoint.py b/haystack_experimental/core/pipeline/breakpoint.py index 508cecd5..6b4c03fd 100644 --- a/haystack_experimental/core/pipeline/breakpoint.py +++ b/haystack_experimental/core/pipeline/breakpoint.py @@ -22,7 +22,7 @@ def _validate_breakpoint(pipeline_breakpoint: Tuple[str, Optional[int]], graph: """ Validates the pipeline_breakpoint passed to the pipeline. - Make sure they are all valid components registered in the pipeline, + Makes sure the breakpoint contains a valid components registered in the pipeline. If the visit is not given, it is assumed to be 0, it will break on the first visit. :param pipeline_breakpoint: Tuple of component name and visit count at which the pipeline should stop. From 018dbead2e67d5ea3a602e69ee102883c20313bd Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Thu, 12 Jun 2025 15:35:50 +0200 Subject: [PATCH 25/30] Update haystack_experimental/core/pipeline/breakpoint.py Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> --- haystack_experimental/core/pipeline/breakpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haystack_experimental/core/pipeline/breakpoint.py b/haystack_experimental/core/pipeline/breakpoint.py index 6b4c03fd..7aec6a41 100644 --- a/haystack_experimental/core/pipeline/breakpoint.py +++ b/haystack_experimental/core/pipeline/breakpoint.py @@ -27,7 +27,7 @@ def _validate_breakpoint(pipeline_breakpoint: Tuple[str, Optional[int]], graph: :param pipeline_breakpoint: Tuple of component name and visit count at which the pipeline should stop. :returns: - Tuple of valid pipeline_breakpoint. + Tuple of component name and visit count representing the `pipeline_breakpoint` """ if pipeline_breakpoint and pipeline_breakpoint[0] not in graph.nodes: From bb149d413080deb1ab4fa67f67ad0788ef14a998 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Thu, 12 Jun 2025 15:36:05 +0200 Subject: [PATCH 26/30] Update haystack_experimental/core/pipeline/breakpoint.py Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> --- haystack_experimental/core/pipeline/breakpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haystack_experimental/core/pipeline/breakpoint.py b/haystack_experimental/core/pipeline/breakpoint.py index 7aec6a41..0f96cd2e 100644 --- a/haystack_experimental/core/pipeline/breakpoint.py +++ b/haystack_experimental/core/pipeline/breakpoint.py @@ -44,7 +44,7 @@ def _validate_pipeline_state(resume_state: Dict[str, Any], graph: Any) -> None: """ Validates that the resume_state contains valid configuration for the current pipeline. - Raises a PipelineRuntimeError if any component is missing or if the state structure is invalid. + Raises a PipelineInvalidResumeStateError if any component is missing or if the state structure is invalid. :param resume_state: The saved state to validate. """ From faaf049b07b5a5e0a61d8652bbb49fdf2fdf02e1 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Thu, 12 Jun 2025 15:36:19 +0200 Subject: [PATCH 27/30] Update haystack_experimental/core/pipeline/breakpoint.py Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> --- haystack_experimental/core/pipeline/breakpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haystack_experimental/core/pipeline/breakpoint.py b/haystack_experimental/core/pipeline/breakpoint.py index 0f96cd2e..a8686e7a 100644 --- a/haystack_experimental/core/pipeline/breakpoint.py +++ b/haystack_experimental/core/pipeline/breakpoint.py @@ -81,7 +81,7 @@ def _validate_pipeline_state(resume_state: Dict[str, Any], graph: Any) -> None: ) -def _validate_resume_state(state: Dict[str, Any]) -> None: +def _validate_resume_state(resume_state: Dict[str, Any]) -> None: """ Validates the loaded pipeline state. From d5441a2fe70b890aeb49ee509cb7f4e2b421d494 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Thu, 12 Jun 2025 16:04:37 +0200 Subject: [PATCH 28/30] PR comments --- .../core/pipeline/breakpoint.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/haystack_experimental/core/pipeline/breakpoint.py b/haystack_experimental/core/pipeline/breakpoint.py index a8686e7a..ca2a0e24 100644 --- a/haystack_experimental/core/pipeline/breakpoint.py +++ b/haystack_experimental/core/pipeline/breakpoint.py @@ -11,6 +11,7 @@ from haystack import Answer, Document, ExtractedAnswer, logging from haystack.dataclasses import ChatMessage, SparseEmbedding +from networkx import MultiDiGraph from haystack_experimental.core.errors import PipelineInvalidResumeStateError from haystack_experimental.dataclasses import GeneratedAnswer @@ -18,7 +19,7 @@ logger = logging.getLogger(__name__) -def _validate_breakpoint(pipeline_breakpoint: Tuple[str, Optional[int]], graph: Any) -> Tuple[str, int]: +def _validate_breakpoint(pipeline_breakpoint: Tuple[str, Optional[int]], graph: MultiDiGraph) -> Tuple[str, int]: """ Validates the pipeline_breakpoint passed to the pipeline. @@ -40,7 +41,7 @@ def _validate_breakpoint(pipeline_breakpoint: Tuple[str, Optional[int]], graph: return valid_breakpoint -def _validate_pipeline_state(resume_state: Dict[str, Any], graph: Any) -> None: +def _validate_pipeline_state(resume_state: Dict[str, Any], graph: MultiDiGraph) -> None: """ Validates that the resume_state contains valid configuration for the current pipeline. @@ -83,9 +84,9 @@ def _validate_pipeline_state(resume_state: Dict[str, Any], graph: Any) -> None: def _validate_resume_state(resume_state: Dict[str, Any]) -> None: """ - Validates the loaded pipeline state. + Validates the loaded pipeline resume_state. - Ensures that the state contains required keys: "input_data", "pipeline_breakpoint", and "pipeline_state". + Ensures that the resume_state contains required keys: "input_data", "pipeline_breakpoint", and "pipeline_state". Raises: ValueError: If required keys are missing or the component sets are inconsistent. @@ -93,12 +94,12 @@ def _validate_resume_state(resume_state: Dict[str, Any]) -> None: # top-level state has all required keys required_top_keys = {"input_data", "pipeline_breakpoint", "pipeline_state"} - missing_top = required_top_keys - state.keys() + missing_top = required_top_keys - resume_state.keys() if missing_top: raise ValueError(f"Invalid state file: missing required keys {missing_top}") # pipeline_state has the necessary keys - pipeline_state = state["pipeline_state"] + pipeline_state = resume_state["pipeline_state"] required_pipeline_keys = {"inputs", "component_visits", "ordered_component_names"} missing_pipeline = required_pipeline_keys - pipeline_state.keys() if missing_pipeline: @@ -121,9 +122,9 @@ def load_state(file_path: Union[str, Path]) -> Dict[str, Any]: """ Load a saved pipeline state. - :param file_path: Path to the state file + :param file_path: Path to the resume_state file :returns: - Dict containing the loaded state + Dict containing the loaded resume_state. """ file_path = Path(file_path) @@ -139,7 +140,7 @@ def load_state(file_path: Union[str, Path]) -> Dict[str, Any]: raise IOError(f"Error reading {file_path}: {str(e)}") try: - _validate_resume_state(state=state) + _validate_resume_state(resume_state=state) except ValueError as e: raise ValueError(f"Invalid pipeline state from {file_path}: {str(e)}") From 073cb40cf1370b4cea106ad1b0b15ec2e1ad112c Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Thu, 12 Jun 2025 16:16:50 +0200 Subject: [PATCH 29/30] Improve docs --- haystack_experimental/core/pipeline/breakpoint.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/haystack_experimental/core/pipeline/breakpoint.py b/haystack_experimental/core/pipeline/breakpoint.py index ca2a0e24..bf25390b 100644 --- a/haystack_experimental/core/pipeline/breakpoint.py +++ b/haystack_experimental/core/pipeline/breakpoint.py @@ -152,7 +152,6 @@ def _save_state( inputs: Dict[str, Any], component_name: str, component_visits: Dict[str, int], - callback_fun: Optional[Callable[..., Any]] = None, debug_path: Optional[Union[str, Path]] = None, original_input_data: Optional[Dict[str, Any]] = None, ordered_component_names: Optional[List[str]] = None, @@ -163,7 +162,6 @@ def _save_state( :param inputs: The current pipeline state inputs. :param component_name: The name of the component that triggered the breakpoint. :param component_visits: The visit count of the component that triggered the breakpoint. - :param callback_fun: A function to call with the saved state. :param debug_path: The path to save the state to. :param original_input_data: The original input data. :param ordered_component_names: The ordered component names. @@ -171,7 +169,14 @@ def _save_state( Exception: If the debug_path is not a string or a Path object, or if saving the JSON state fails. :returns: - The saved state dictionary + The dictionary containing the state of the pipeline containing the following keys: + - input_data: The original input data passed to the pipeline. + - timestamp: The timestamp of the breakpoint. + - pipeline_breakpoint: The component name and visit count that triggered the breakpoint. + - pipeline_state: The state of the pipeline when the breakpoint was triggered containing the following keys: + - inputs: The current state of inputs for pipeline components. + - component_visits: The visit count of the components when the breakpoint was triggered. + - ordered_component_names: The order of components in the pipeline. """ dt = datetime.now() @@ -200,10 +205,6 @@ def _save_state( json.dump(state, f_out, indent=2) logger.info(f"Pipeline state saved at: {file_name}") - # pass the state to some user-defined callback function - if callback_fun is not None: - callback_fun(state) - return state except Exception as e: From cdf9022ebee8b5895b946078c021c58edf7eaf1c Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Thu, 12 Jun 2025 16:18:53 +0200 Subject: [PATCH 30/30] Improve docs --- haystack_experimental/core/pipeline/breakpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haystack_experimental/core/pipeline/breakpoint.py b/haystack_experimental/core/pipeline/breakpoint.py index bf25390b..7267b7a7 100644 --- a/haystack_experimental/core/pipeline/breakpoint.py +++ b/haystack_experimental/core/pipeline/breakpoint.py @@ -7,7 +7,7 @@ import json from datetime import datetime from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from haystack import Answer, Document, ExtractedAnswer, logging from haystack.dataclasses import ChatMessage, SparseEmbedding