diff --git a/haystack_experimental/components/builders/answer_builder.py b/haystack_experimental/components/builders/answer_builder.py new file mode 100644 index 00000000..787e9393 --- /dev/null +++ b/haystack_experimental/components/builders/answer_builder.py @@ -0,0 +1,130 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional, Union + +from haystack import Document, component, logging +from haystack.components.builders.answer_builder import AnswerBuilder as HaystackAnswerBuilder +from haystack.dataclasses.chat_message import ChatMessage + +from haystack_experimental.dataclasses import GeneratedAnswer + +logger = logging.getLogger(__name__) + + +@component +class AnswerBuilder(HaystackAnswerBuilder): + """ + Converts a query and Generator replies into a `GeneratedAnswer` object. + + AnswerBuilder parses Generator replies using custom regular expressions. + Check out the usage example below to see how it works. + Optionally, it can also take documents and metadata from the Generator to add to the `GeneratedAnswer` object. + AnswerBuilder works with both non-chat and chat Generators. + + ### Usage example + + ```python + from haystack.components.builders import AnswerBuilder + + builder = AnswerBuilder(pattern="Answer: (.*)") + builder.run(query="What's the answer?", replies=["This is an argument. Answer: This is the answer."]) + ``` + """ + + @component.output_types(answers=List[GeneratedAnswer]) + def run( # pylint: disable=too-many-positional-arguments + self, + query: str, + replies: Union[List[str], List[ChatMessage]], + meta: Optional[List[Dict[str, Any]]] = None, + documents: Optional[List[Document]] = None, + pattern: Optional[str] = None, + reference_pattern: Optional[str] = None, + ) -> Dict[str, List[GeneratedAnswer]]: + """ + Turns the output of a Generator into `GeneratedAnswer` objects using regular expressions. + + :param query: + The input query used as the Generator prompt. + :param replies: + The output of the Generator. Can be a list of strings or a list of `ChatMessage` objects. + :param meta: + The metadata returned by the Generator. If not specified, the generated answer will contain no metadata. + :param documents: + The documents used as the Generator inputs. If specified, they are added to + the`GeneratedAnswer` objects. + If both `documents` and `reference_pattern` are specified, the documents referenced in the + Generator output are extracted from the input documents and added to the `GeneratedAnswer` objects. + :param pattern: + The regular expression pattern to extract the answer text from the Generator. + If not specified, the entire response is used as the answer. + The regular expression can have one capture group at most. + If present, the capture group text + is used as the answer. If no capture group is present, the whole match is used as the answer. + Examples: + `[^\\n]+$` finds "this is an answer" in a string "this is an argument.\\nthis is an answer". + `Answer: (.*)` finds "this is an answer" in a string + "this is an argument. Answer: this is an answer". + :param reference_pattern: + The regular expression pattern used for parsing the document references. + If not specified, no parsing is done, and all documents are referenced. + References need to be specified as indices of the input documents and start at [1]. + Example: `\\[(\\d+)\\]` finds "1" in a string "this is an answer[1]". + + :returns: A dictionary with the following keys: + - `answers`: The answers received from the output of the Generator. + """ + if not meta: + meta = [{}] * len(replies) + elif len(replies) != len(meta): + raise ValueError(f"Number of replies ({len(replies)}), and metadata ({len(meta)}) must match.") + + if pattern: + AnswerBuilder._check_num_groups_in_regex(pattern) + + pattern = pattern or self.pattern + reference_pattern = reference_pattern or self.reference_pattern + all_answers = [] + + replies_to_iterate = replies + meta_to_iterate = meta + + if self.last_message_only and replies: + replies_to_iterate = replies[-1:] + meta_to_iterate = meta[-1:] + + for reply, given_metadata in zip(replies_to_iterate, meta_to_iterate): + # Extract content from ChatMessage objects if reply is a ChatMessages, else use the string as is + if isinstance(reply, ChatMessage): + extracted_reply = reply.text or "" + else: + extracted_reply = str(reply) + extracted_metadata = reply.meta if isinstance(reply, ChatMessage) else {} + + extracted_metadata = {**extracted_metadata, **given_metadata} + extracted_metadata["all_messages"] = replies + + referenced_docs = [] + if documents: + if reference_pattern: + reference_idxs = AnswerBuilder._extract_reference_idxs(extracted_reply, reference_pattern) + else: + reference_idxs = [doc_idx for doc_idx, _ in enumerate(documents)] + + for idx in reference_idxs: + try: + referenced_docs.append(documents[idx]) + except IndexError: + logger.warning( + "Document index '{index}' referenced in Generator output is out of range. ", index=idx + 1 + ) + + answer_string = AnswerBuilder._extract_answer_string(extracted_reply, pattern) + answer = GeneratedAnswer( + data=answer_string, query=query, documents=referenced_docs, meta=extracted_metadata + ) + all_answers.append(answer) + + return {"answers": all_answers} diff --git a/haystack_experimental/core/pipeline/base.py b/haystack_experimental/core/pipeline/base.py index a729689a..8f19f066 100644 --- a/haystack_experimental/core/pipeline/base.py +++ b/haystack_experimental/core/pipeline/base.py @@ -13,7 +13,9 @@ class PipelineBase(HaystackPipelineBase): @staticmethod - def _consume_component_inputs(component_name: str, component: Dict, inputs: Dict) -> Dict[str, Any]: + def _consume_component_inputs( + component_name: str, component: Dict, inputs: Dict, is_resume: bool = False + ) -> Dict[str, Any]: """ Extracts the inputs needed to run for the component and removes them from the global inputs state. @@ -28,6 +30,11 @@ def _consume_component_inputs(component_name: str, component: Dict, inputs: Dict for socket_name, socket in component["input_sockets"].items(): socket_inputs = component_inputs.get(socket_name, []) socket_inputs = [sock["value"] for sock in socket_inputs if sock["value"] is not _NO_OUTPUT_PRODUCED] + + # if we are resuming a component, the inputs are already consumed, so we just return the first input + if is_resume: + consumed_inputs[socket_name] = socket_inputs[0] + continue if socket_inputs: if not socket.is_variadic: # We only care about the first input provided to the socket. diff --git a/haystack_experimental/core/pipeline/breakpoint.py b/haystack_experimental/core/pipeline/breakpoint.py new file mode 100644 index 00000000..7267b7a7 --- /dev/null +++ b/haystack_experimental/core/pipeline/breakpoint.py @@ -0,0 +1,319 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +# pylint: disable=too-many-return-statements, too-many-positional-arguments + +import json +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +from haystack import Answer, Document, ExtractedAnswer, logging +from haystack.dataclasses import ChatMessage, SparseEmbedding +from networkx import MultiDiGraph + +from haystack_experimental.core.errors import PipelineInvalidResumeStateError +from haystack_experimental.dataclasses import GeneratedAnswer + +logger = logging.getLogger(__name__) + + +def _validate_breakpoint(pipeline_breakpoint: Tuple[str, Optional[int]], graph: MultiDiGraph) -> Tuple[str, int]: + """ + Validates the pipeline_breakpoint passed to the pipeline. + + Makes sure the breakpoint contains a valid components registered in the pipeline. + If the visit is not given, it is assumed to be 0, it will break on the first visit. + + :param pipeline_breakpoint: Tuple of component name and visit count at which the pipeline should stop. + :returns: + Tuple of component name and visit count representing the `pipeline_breakpoint` + """ + + if pipeline_breakpoint and pipeline_breakpoint[0] not in graph.nodes: + raise ValueError(f"pipeline_breakpoint {pipeline_breakpoint} is not a registered component in the pipeline") + valid_breakpoint: Tuple[str, int] = ( + (pipeline_breakpoint[0], 0 if pipeline_breakpoint[1] is None else pipeline_breakpoint[1]) + if pipeline_breakpoint + else None + ) + return valid_breakpoint + + +def _validate_pipeline_state(resume_state: Dict[str, Any], graph: MultiDiGraph) -> None: + """ + Validates that the resume_state contains valid configuration for the current pipeline. + + Raises a PipelineInvalidResumeStateError if any component is missing or if the state structure is invalid. + + :param resume_state: The saved state to validate. + """ + + pipeline_state = resume_state["pipeline_state"] + valid_components = set(graph.nodes.keys()) + + # Check if the ordered_component_names are valid components in the pipeline + missing_ordered = set(pipeline_state["ordered_component_names"]) - valid_components + if missing_ordered: + raise PipelineInvalidResumeStateError( + f"Invalid resume state: components {missing_ordered} in 'ordered_component_names' " + f"are not part of the current pipeline." + ) + + # Check if the input_data is valid components in the pipeline + missing_input = set(resume_state["input_data"].keys()) - valid_components + if missing_input: + raise PipelineInvalidResumeStateError( + f"Invalid resume state: components {missing_input} in 'input_data' are not part of the current pipeline." + ) + + # Validate 'component_visits' + missing_visits = set(pipeline_state["component_visits"].keys()) - valid_components + if missing_visits: + raise PipelineInvalidResumeStateError( + f"Invalid resume state: components {missing_visits} in 'component_visits' " + f"are not part of the current pipeline." + ) + + logger.info( + f"Resuming pipeline from component: {resume_state['pipeline_breakpoint']['component']} " + f"(visit {resume_state['pipeline_breakpoint']['visits']})" + ) + + +def _validate_resume_state(resume_state: Dict[str, Any]) -> None: + """ + Validates the loaded pipeline resume_state. + + Ensures that the resume_state contains required keys: "input_data", "pipeline_breakpoint", and "pipeline_state". + + Raises: + ValueError: If required keys are missing or the component sets are inconsistent. + """ + + # top-level state has all required keys + required_top_keys = {"input_data", "pipeline_breakpoint", "pipeline_state"} + missing_top = required_top_keys - resume_state.keys() + if missing_top: + raise ValueError(f"Invalid state file: missing required keys {missing_top}") + + # pipeline_state has the necessary keys + pipeline_state = resume_state["pipeline_state"] + required_pipeline_keys = {"inputs", "component_visits", "ordered_component_names"} + missing_pipeline = required_pipeline_keys - pipeline_state.keys() + if missing_pipeline: + raise ValueError(f"Invalid pipeline_state: missing required keys {missing_pipeline}") + + # component_visits and ordered_component_names must be consistent + components_in_state = set(pipeline_state["component_visits"].keys()) + components_in_order = set(pipeline_state["ordered_component_names"]) + + if components_in_state != components_in_order: + raise ValueError( + f"Inconsistent state: components in pipeline_state['component_visits'] {components_in_state} " + f"do not match components in ordered_component_names {components_in_order}" + ) + + logger.info("Passed resume state validated successfully.") + + +def load_state(file_path: Union[str, Path]) -> Dict[str, Any]: + """ + Load a saved pipeline state. + + :param file_path: Path to the resume_state file + :returns: + Dict containing the loaded resume_state. + """ + + file_path = Path(file_path) + + try: + with open(file_path, "r", encoding="utf-8") as f: + state = json.load(f) + except FileNotFoundError: + raise FileNotFoundError(f"File not found: {file_path}") + except json.JSONDecodeError as e: + raise json.JSONDecodeError(f"Invalid JSON file {file_path}: {str(e)}", e.doc, e.pos) + except IOError as e: + raise IOError(f"Error reading {file_path}: {str(e)}") + + try: + _validate_resume_state(resume_state=state) + except ValueError as e: + raise ValueError(f"Invalid pipeline state from {file_path}: {str(e)}") + + logger.info(f"Successfully loaded pipeline state from: {file_path}") + return state + + +def _save_state( + inputs: Dict[str, Any], + component_name: str, + component_visits: Dict[str, int], + debug_path: Optional[Union[str, Path]] = None, + original_input_data: Optional[Dict[str, Any]] = None, + ordered_component_names: Optional[List[str]] = None, +) -> Dict[str, Any]: + """ + Save the pipeline state to a file. + + :param inputs: The current pipeline state inputs. + :param component_name: The name of the component that triggered the breakpoint. + :param component_visits: The visit count of the component that triggered the breakpoint. + :param debug_path: The path to save the state to. + :param original_input_data: The original input data. + :param ordered_component_names: The ordered component names. + :raises: + Exception: If the debug_path is not a string or a Path object, or if saving the JSON state fails. + + :returns: + The dictionary containing the state of the pipeline containing the following keys: + - input_data: The original input data passed to the pipeline. + - timestamp: The timestamp of the breakpoint. + - pipeline_breakpoint: The component name and visit count that triggered the breakpoint. + - pipeline_state: The state of the pipeline when the breakpoint was triggered containing the following keys: + - inputs: The current state of inputs for pipeline components. + - component_visits: The visit count of the components when the breakpoint was triggered. + - ordered_component_names: The order of components in the pipeline. + """ + dt = datetime.now() + + state = { + "input_data": _serialize_component_input(original_input_data), # original input data + "timestamp": dt.isoformat(), + "pipeline_breakpoint": {"component": component_name, "visits": component_visits[component_name]}, + "pipeline_state": { + "inputs": _serialize_component_input(inputs), # current pipeline state inputs + "component_visits": component_visits, + "ordered_component_names": ordered_component_names, + }, + } + + if not debug_path: + return state + + debug_path = Path(debug_path) if isinstance(debug_path, str) else debug_path + if not isinstance(debug_path, Path): + raise ValueError("Debug path must be a string or a Path object.") + debug_path.mkdir(exist_ok=True) + file_name = Path(f"{component_name}_{dt.strftime('%Y_%m_%d_%H_%M_%S')}.json") + + try: + with open(debug_path / file_name, "w") as f_out: + json.dump(state, f_out, indent=2) + logger.info(f"Pipeline state saved at: {file_name}") + + return state + + except Exception as e: + logger.error(f"Failed to save pipeline state: {str(e)}") + raise + + +def _deserialize_component_input(value: Any) -> Any: # noqa: PLR0911 + """ + Tries to deserialize any type of input that can be passed to as input to a pipeline component. + + For primitive values, it returns the value as is, but for complex types, it tries to deserialize them. + """ + + # None or primitive types are returned as is + if not value or isinstance(value, (str, int, float, bool)): + return value + + # list of primitive types are returned as is + if isinstance(value, list) and all(isinstance(i, (str, int, float, bool)) for i in value): + return value + + if isinstance(value, list): + # list of lists are called recursively + if all(isinstance(i, list) for i in value): + return [_deserialize_component_input(i) for i in value] + # list of dicts are called recursively + if all(isinstance(i, dict) for i in value): + return [_deserialize_component_input(i) for i in value] + + # Define the mapping of types to their deserialization functions + _type_deserializers = { + "Answer": Answer.from_dict, + "ChatMessage": ChatMessage.from_dict, + "Document": Document.from_dict, + "ExtractedAnswer": ExtractedAnswer.from_dict, + "GeneratedAnswer": GeneratedAnswer.from_dict, + "SparseEmbedding": SparseEmbedding.from_dict, + } + + # check if the dictionary has a "_type" key and if it's a known type + if isinstance(value, dict): + if "_type" in value: + type_name = value.pop("_type") + if type_name in _type_deserializers: + return _type_deserializers[type_name](value) + + # If not a known type, recursively deserialize each item in the dictionary + return {k: _deserialize_component_input(v) for k, v in value.items()} + + return value + + +def _transform_json_structure(data: Union[Dict[str, Any], List[Any], Any]) -> Any: + """ + Transforms a JSON structure by removing the 'sender' key and moving the 'value' to the top level. + + For example: + "key": [{"sender": null, "value": "some value"}] -> "key": "some value" + + :param data: The JSON structure to transform. + :returns: The transformed structure. + """ + if isinstance(data, dict): + # If this dict has both 'sender' and 'value', return just the value + if "value" in data and "sender" in data: + return data["value"] + # Otherwise, recursively process each key-value pair + return {k: _transform_json_structure(v) for k, v in data.items()} + + if isinstance(data, list): + # First, transform each item in the list. + transformed = [_transform_json_structure(item) for item in data] + # If the original list has exactly one element and that element was a dict + # with 'sender' and 'value', then unwrap the list. + if len(data) == 1 and isinstance(data[0], dict) and "value" in data[0] and "sender" in data[0]: + return transformed[0] + return transformed + + # For other data types, just return the value as is. + return data + + +def _serialize_component_input(value: Any) -> Any: + """ + Serializes, so it can be saved to a file, any type of input to a pipeline component. + + :param value: The value to serialize. + :returns: The serialized value that can be saved to a file. + """ + value = _transform_json_structure(value) + if hasattr(value, "to_dict") and callable(getattr(value, "to_dict")): + serialized_value = value.to_dict() + serialized_value["_type"] = value.__class__.__name__ + return serialized_value + + # this is a hack to serialize inputs that don't have a to_dict + elif hasattr(value, "__dict__"): + return { + "_type": value.__class__.__name__, + "attributes": value.__dict__, + } + + # recursively serialize all inputs in a dict + elif isinstance(value, dict): + return {k: _serialize_component_input(v) for k, v in value.items()} + + # recursively serialize all inputs in lists or tuples + elif isinstance(value, list): + return [_serialize_component_input(item) for item in value] + + return value diff --git a/haystack_experimental/core/pipeline/pipeline.py b/haystack_experimental/core/pipeline/pipeline.py index 68217fcd..cbeea62d 100644 --- a/haystack_experimental/core/pipeline/pipeline.py +++ b/haystack_experimental/core/pipeline/pipeline.py @@ -5,26 +5,20 @@ # pylint: disable=too-many-return-statements, too-many-positional-arguments -import json from copy import deepcopy -from datetime import datetime from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Optional, Set, Tuple, Union -from haystack import Answer, Document, ExtractedAnswer, logging, tracing -from haystack.components.joiners import DocumentJoiner -from haystack.core.component import Component +from haystack import logging, tracing from haystack.core.pipeline.base import ComponentPriority from haystack.core.pipeline.pipeline import Pipeline as HaystackPipeline -from haystack.dataclasses import ChatMessage, GeneratedAnswer, SparseEmbedding from haystack.telemetry import pipeline_running -from haystack_experimental.core.errors import ( - PipelineBreakpointException, - PipelineInvalidResumeStateError, -) +from haystack_experimental.core.errors import PipelineBreakpointException, PipelineInvalidResumeStateError from haystack_experimental.core.pipeline.base import PipelineBase +from .breakpoint import _deserialize_component_input, _save_state, _validate_breakpoint, _validate_pipeline_state + logger = logging.getLogger(__name__) @@ -41,8 +35,7 @@ def run( # noqa: PLR0915, PLR0912 self, data: Dict[str, Any], include_outputs_from: Optional[Set[str]] = None, - breakpoints: Optional[Set[Tuple[str, Optional[int]]]] = None, - *, + pipeline_breakpoint: Optional[Tuple[str, Optional[int]]] = None, resume_state: Optional[Dict[str, Any]] = None, debug_path: Optional[Union[str, Path]] = None, ) -> Dict[str, Any]: @@ -121,8 +114,8 @@ def run( # noqa: PLR0915, PLR0912 invoked multiple times (in a loop), only the last-produced output is included. - :param breakpoints: - Set of tuples of component names and visit counts at which the pipeline should break execution. + :param pipeline_breakpoint: + Tuple of component name and visit count at which the pipeline should break execution. If the visit count is not given, it is assumed to be 0, it will break on the first visit. :param resume_state: @@ -146,17 +139,19 @@ def run( # noqa: PLR0915, PLR0912 :raises PipelineMaxComponentRuns: If a Component reaches the maximum number of times it can be run in this Pipeline. :raises PipelineBreakpointException: - When a breakpoint is triggered. Contains the component name, state, and partial results. + When a pipeline_breakpoint is triggered. Contains the component name, state, and partial results. """ pipeline_running(self) - if breakpoints and resume_state: - logger.warning( - "Breakpoints cannot be provided when resuming a pipeline. All breakpoints will be ignored.", + if pipeline_breakpoint and resume_state: + msg = ( + "pipeline_breakpoint and resume_state cannot be provided at the same time. " + "The pipeline run will be aborted." ) + raise PipelineInvalidResumeStateError(message=msg) - # make sure breakpoints are valid and have a default visit count - validated_breakpoints = self._validate_breakpoints(breakpoints) if breakpoints else None + # make sure pipeline_breakpoint is valid and have a default visit count + validated_breakpoint = _validate_breakpoint(pipeline_breakpoint, self.graph) if pipeline_breakpoint else None # TODO: Remove this warmup once we can check reliably whether a component has been warmed up or not # As of now it's here to make sure we don't have failing tests that assume warm_up() is called in run() @@ -206,98 +201,100 @@ def run( # noqa: PLR0915, PLR0912 # check if pipeline is blocked before execution self.validate_pipeline(priority_queue) - try: - while True: - candidate = self._get_next_runnable_component(priority_queue, component_visits) - if candidate is None: - break + while True: + candidate = self._get_next_runnable_component(priority_queue, component_visits) + if candidate is None: + break - priority, component_name, component = candidate + priority, component_name, component = candidate - if len(priority_queue) > 0 and priority in [ComponentPriority.DEFER, ComponentPriority.DEFER_LAST]: - component_name, topological_sort = self._tiebreak_waiting_components( - component_name=component_name, - priority=priority, - priority_queue=priority_queue, - topological_sort=cached_topological_sort, - ) - cached_topological_sort = topological_sort - component = self._get_component_with_graph_metadata_and_visits( - component_name, component_visits[component_name] - ) + if len(priority_queue) > 0 and priority in [ComponentPriority.DEFER, ComponentPriority.DEFER_LAST]: + component_name, topological_sort = self._tiebreak_waiting_components( + component_name=component_name, + priority=priority, + priority_queue=priority_queue, + topological_sort=cached_topological_sort, + ) - # this breaks the pipeline breakpoints - component_inputs = self._consume_component_inputs( - component_name=component_name, component=component, inputs=inputs + cached_topological_sort = topological_sort + component = self._get_component_with_graph_metadata_and_visits( + component_name, component_visits[component_name] ) - # We need to add missing defaults using default values from input sockets because the run signature - # might not provide these defaults for components with inputs defined dynamically upon component - # initialization - component_inputs = self._add_missing_input_defaults(component_inputs, component["input_sockets"]) - - # Deserialize the component_inputs if they are passed in resume state - # this check will prevent other component_inputs generated at runtime from being deserialized - if resume_state and component_name in resume_state["pipeline_state"]["inputs"].keys(): - for key, value in component_inputs.items(): - component_inputs[key] = deserialize_component_input(value) - - if validated_breakpoints and not resume_state: - state_inputs_serialised = remove_unserializable_data(deepcopy(inputs)) - # inject the component_inputs into the state_inputs so we can this component init params in - # the JSON state - state_inputs_serialised[component_name] = remove_unserializable_data(deepcopy(component_inputs)) - - Pipeline._check_breakpoints( - breakpoints=validated_breakpoints, - component_name=component_name, - component_visits=component_visits, + + is_resume = bool(resume_state and resume_state["pipeline_breakpoint"]["component"] == component_name) + component_inputs = self._consume_component_inputs( + component_name=component_name, component=component, inputs=inputs, is_resume=is_resume + ) + + # We need to add missing defaults using default values from input sockets because the run signature + # might not provide these defaults for components with inputs defined dynamically upon component + # initialization + component_inputs = self._add_missing_input_defaults(component_inputs, component["input_sockets"]) + + # Scenario 1: Resume state is provided to resume the pipeline at a specific component + + # Deserialize the component_inputs if they are passed in resume state + # this check will prevent other component_inputs generated at runtime from being deserialized + if resume_state and component_name in resume_state["pipeline_state"]["inputs"].keys(): + for key, value in component_inputs.items(): + component_inputs[key] = _deserialize_component_input(value) + + # Scenario 2: pipeline_breakpoint is provided to stop the pipeline at + # a specific component and visit count + + if validated_breakpoint is not None: + breakpoint_component, visit_count = validated_breakpoint + breakpoint_triggered = bool( + breakpoint_component == component_name and visit_count == component_visits[component_name] + ) + if breakpoint_triggered: + state_inputs_serialised = deepcopy(inputs) + state_inputs_serialised[component_name] = deepcopy(component_inputs) + + _save_state( inputs=state_inputs_serialised, + component_name=str(component_name), + component_visits=component_visits, debug_path=debug_path, original_input_data=data, ordered_component_names=ordered_component_names, ) + msg = f"Breaking at component {component_name} visit count {component_visits[component_name]}" + logger.info(msg) + raise PipelineBreakpointException( + message=msg, + component=component_name, + state=state_inputs_serialised, + results=pipeline_outputs, + ) - # the _consume_component_inputs() when applied to the DocumentJoiner inputs wraps 'documents' in an - # extra list, so there's a 3 level deep list, we need to flatten it to 2 levels only - instance: Component = component["instance"] - if resume_state and isinstance(instance, DocumentJoiner): # noqa: SIM102 - if isinstance(component_inputs["documents"], list): # noqa: SIM102 - if isinstance(component_inputs["documents"][0], list): # noqa: SIM102 - if isinstance(component_inputs["documents"][0][0], list): # noqa: SIM102 - component_inputs["documents"] = component_inputs["documents"][0] - - component_outputs = self._run_component( - component_name=component_name, - component=component, - inputs=component_inputs, # the inputs to the current component - component_visits=component_visits, - parent_span=span, - ) - - # Updates global input state with component outputs and returns outputs that should go to - # pipeline outputs. - component_pipeline_outputs = self._write_component_outputs( - component_name=component_name, - component_outputs=component_outputs, - inputs=inputs, - receivers=cached_receivers[component_name], - include_outputs_from=include_outputs_from, - ) + component_outputs = self._run_component( + component_name=component_name, + component=component, + inputs=component_inputs, # the inputs to the current component + component_visits=component_visits, + parent_span=span, + ) - if component_pipeline_outputs: - pipeline_outputs[component_name] = deepcopy(component_pipeline_outputs) - if self._is_queue_stale(priority_queue): - priority_queue = self._fill_queue(ordered_component_names, inputs, component_visits) + # Updates global input state with component outputs and returns outputs that should go to + # pipeline outputs. + component_pipeline_outputs = self._write_component_outputs( + component_name=component_name, + component_outputs=component_outputs, + inputs=inputs, + receivers=cached_receivers[component_name], + include_outputs_from=include_outputs_from, + ) - except PipelineBreakpointException as e: - # Add the current pipeline results to the exception - e.results = pipeline_outputs - raise + if component_pipeline_outputs: + pipeline_outputs[component_name] = deepcopy(component_pipeline_outputs) + if self._is_queue_stale(priority_queue): + priority_queue = self._fill_queue(ordered_component_names, inputs, component_visits) - if breakpoints: - logger.warning(f"Given breakpoint {breakpoints} was never triggered. This is because:") + if pipeline_breakpoint: + logger.warning(f"Given pipeline_breakpoint {pipeline_breakpoint} was never triggered. This is because:") logger.warning("1. The provided component is not a part of the pipeline execution path.") - logger.warning("2. The component did not reach the visit count specified in the breakpoint") + logger.warning("2. The component did not reach the visit count specified in the pipeline_breakpoint") return pipeline_outputs def inject_resume_state_into_graph(self, resume_state): @@ -310,360 +307,13 @@ def inject_resume_state_into_graph(self, resume_state): if not resume_state: raise PipelineInvalidResumeStateError("Cannot inject resume state: resume_state is None") - self._validate_pipeline_state(resume_state) + _validate_pipeline_state(resume_state, graph=self.graph) data = self._prepare_component_input_data(resume_state["pipeline_state"]["inputs"]) component_visits = resume_state["pipeline_state"]["component_visits"] ordered_component_names = resume_state["pipeline_state"]["ordered_component_names"] msg = ( - f"Resuming pipeline from {resume_state['breakpoint']['component']} " - f"visit count {resume_state['breakpoint']['visits']}" + f"Resuming pipeline from {resume_state['pipeline_breakpoint']['component']} " + f"visit count {resume_state['pipeline_breakpoint']['visits']}" ) logger.info(msg) return component_visits, data, resume_state, ordered_component_names - - def _validate_breakpoints(self, breakpoints: Set[Tuple[str, Optional[int]]]) -> Set[Tuple[str, int]]: - """ - Validates the breakpoints passed to the pipeline. - - Make sure they are all valid components registered in the pipeline, - If the visit is not given, it is assumed to be 0, it will break on the first visit. - - :param breakpoints: Set of tuples of component names and visit counts at which the pipeline should stop. - :returns: - Set of valid breakpoints. - """ - - processed_breakpoints: Set[Tuple[str, int]] = set() - - for break_point in breakpoints: - if break_point[0] not in self.graph.nodes: - raise ValueError(f"Breakpoint {break_point} is not a registered component in the pipeline") - valid_breakpoint: Tuple[str, int] = (break_point[0], 0 if break_point[1] is None else break_point[1]) - processed_breakpoints.add(valid_breakpoint) - return processed_breakpoints - - def _validate_pipeline_state(self, resume_state: Dict[str, Any]) -> None: - """ - Validates that the resume_state contains valid configuration for the current pipeline. - - Raises a PipelineRuntimeError if any component is missing or if the state structure is invalid. - - :param resume_state: The saved state to validate. - """ - - pipeline_state = resume_state["pipeline_state"] - valid_components = set(self.graph.nodes.keys()) - - # Check if the ordered_component_names are valid components in the pipeline - missing_ordered = set(pipeline_state["ordered_component_names"]) - valid_components - if missing_ordered: - raise PipelineInvalidResumeStateError( - f"Invalid resume state: components {missing_ordered} in 'ordered_component_names' " - f"are not part of the current pipeline." - ) - - # Check if the input_data is valid components in the pipeline - missing_input = set(resume_state["input_data"].keys()) - valid_components - if missing_input: - raise PipelineInvalidResumeStateError( - f"Invalid resume state: components {missing_input} in 'input_data' " - f"are not part of the current pipeline." - ) - - # Validate 'component_visits' - missing_visits = set(pipeline_state["component_visits"].keys()) - valid_components - if missing_visits: - raise PipelineInvalidResumeStateError( - f"Invalid resume state: components {missing_visits} in 'component_visits' " - f"are not part of the current pipeline." - ) - - logger.info( - f"Resuming pipeline from component: {resume_state['breakpoint']['component']} " - f"(visit {resume_state['breakpoint']['visits']})" - ) - - @staticmethod - def _validate_resume_state(state: Dict[str, Any]) -> None: - """ - Validates the loaded pipeline state. - - Ensures that the state contains required keys: "input_data", "breakpoint", and "pipeline_state". - - Raises: - ValueError: If required keys are missing or the component sets are inconsistent. - """ - - # top-level state has all required keys - required_top_keys = {"input_data", "breakpoint", "pipeline_state"} - missing_top = required_top_keys - state.keys() - if missing_top: - raise ValueError(f"Invalid state file: missing required keys {missing_top}") - - # pipeline_state has the necessary keys - pipeline_state = state["pipeline_state"] - required_pipeline_keys = {"inputs", "component_visits", "ordered_component_names"} - missing_pipeline = required_pipeline_keys - pipeline_state.keys() - if missing_pipeline: - raise ValueError(f"Invalid pipeline_state: missing required keys {missing_pipeline}") - - # component_visits and ordered_component_names must be consistent - components_in_state = set(pipeline_state["component_visits"].keys()) - components_in_order = set(pipeline_state["ordered_component_names"]) - - if components_in_state != components_in_order: - raise ValueError( - f"Inconsistent state: components in pipeline_state['component_visits'] {components_in_state} " - f"do not match components in ordered_component_names {components_in_order}" - ) - - logger.info("Passed resume state validated successfully.") - - @staticmethod - def load_state(file_path: Union[str, Path]) -> Dict[str, Any]: - """ - Load a saved pipeline state. - - :param file_path: Path to the state file - :returns: - Dict containing the loaded state - """ - - file_path = Path(file_path) - - try: - with open(file_path, "r", encoding="utf-8") as f: - state = json.load(f) - except FileNotFoundError: - raise FileNotFoundError(f"File not found: {file_path}") - except json.JSONDecodeError as e: - raise json.JSONDecodeError(f"Invalid JSON file {file_path}: {str(e)}", e.doc, e.pos) - except IOError as e: - raise IOError(f"Error reading {file_path}: {str(e)}") - - try: - Pipeline._validate_resume_state(state=state) - except ValueError as e: - raise ValueError(f"Invalid pipeline state from {file_path}: {str(e)}") - - logger.info(f"Successfully loaded pipeline state from: {file_path}") - return state - - @staticmethod - def save_state( - inputs: Dict[str, Any], - component_name: str, - component_visits: Dict[str, int], - callback_fun: Optional[Callable[..., Any]] = None, - debug_path: Optional[Union[str, Path]] = None, - original_input_data: Optional[Dict[str, Any]] = None, - ordered_component_names: Optional[List[str]] = None, - ) -> Dict[str, Any]: - """ - If a debug_path is given it saves the JSON state of the pipeline at a given component visit count in a file. - - If debug_path is not given, it returns the JSON state as a dictionary without saving it to a file. - - :raises: - Exception: If the debug_path is not a string or a Path object, or if saving the JSON state fails. - - :returns: - The saved state dictionary - """ - dt = datetime.now() - state = { - "input_data": serialize_component_input(original_input_data), # original input data - "timestamp": dt.isoformat(), - "breakpoint": {"component": component_name, "visits": component_visits[component_name]}, - "pipeline_state": { - "inputs": serialize_component_input(inputs), # current pipeline state inputs - "component_visits": component_visits, - "ordered_component_names": ordered_component_names, - }, - } - if not debug_path: - return state - - if isinstance(debug_path, str): - debug_path = Path(debug_path) - if not isinstance(debug_path, Path): - raise ValueError("Debug path must be a string or a Path object.") - debug_path.mkdir(exist_ok=True) - file_name = Path(f"{component_name}_{dt.strftime('%Y_%m_%d_%H_%M_%S')}.json") - - try: - with open(debug_path / file_name, "w") as f_out: - json.dump(state, f_out, indent=2) - logger.info(f"Pipeline state saved at: {file_name}") - - # pass the state to some user-defined callback function - if callback_fun is not None: - callback_fun(state) - - return state - - except Exception as e: - logger.error(f"Failed to save pipeline state: {str(e)}") - raise - - @staticmethod - def _check_breakpoints( - breakpoints: Set[Tuple[str, int]], - component_name: str, - component_visits: Dict[str, int], - inputs: Dict[str, Any], - debug_path: Optional[Union[str, Path]] = None, - original_input_data: Optional[Dict[str, Any]] = None, - ordered_component_names: Optional[List[str]] = None, - ) -> None: - """ - Check if the `component_name` is in the breakpoints and if it should break. - - :param breakpoints: Set of tuples of component names and visit counts at which the pipeline should stop. - :param component_name: Name of the component to check. - :param component_visits: The number of times the component has been visited. - :param inputs: The inputs to the pipeline. - :raises PipelineBreakpointException: When a breakpoint is triggered, with component state information. - """ - matching_breakpoints = [bp for bp in breakpoints if bp[0] == component_name] - - for bp in matching_breakpoints: - visit_count = bp[1] - # break only if the visit count is the same - if visit_count == component_visits[component_name]: - msg = f"Breaking at component {component_name} visit count {component_visits[component_name]}" - logger.info(msg) - state = Pipeline.save_state( - inputs=inputs, - component_name=str(component_name), - component_visits=component_visits, - debug_path=debug_path, - original_input_data=original_input_data, - ordered_component_names=ordered_component_names, - ) - raise PipelineBreakpointException(msg, component=component_name, state=state) - - -def deserialize_component_input(value): # noqa: PLR0911 - """ - Tries to deserialize any type of input that can be passed to as input to a pipeline component. - - For primitive values, it returns the value as is, but for complex types, it tries to deserialize them. - """ - - # None or primitive types are returned as is - if not value or isinstance(value, (str, int, float, bool)): - return value - - # list of primitive types are returned as is - if isinstance(value, list) and all(isinstance(i, (str, int, float, bool)) for i in value): - return value - - if isinstance(value, list): - # list of lists are called recursively - if all(isinstance(i, list) for i in value): - return [deserialize_component_input(i) for i in value] - # list of dicts are called recursively - if all(isinstance(i, dict) for i in value): - return [deserialize_component_input(i) for i in value] - - # Define the mapping of types to their deserialization functions - _type_deserializers = { - "Answer": Answer.from_dict, - "ChatMessage": ChatMessage.from_dict, - "Document": Document.from_dict, - "ExtractedAnswer": ExtractedAnswer.from_dict, - "GeneratedAnswer": GeneratedAnswer.from_dict, - "SparseEmbedding": SparseEmbedding.from_dict, - } - - # check if the dictionary has a "_type" key and if it's a known type - if isinstance(value, dict): - if "_type" in value: - type_name = value.pop("_type") - if type_name in _type_deserializers: - return _type_deserializers[type_name](value) - - # If not a known type, recursively deserialize each item in the dictionary - return {k: deserialize_component_input(v) for k, v in value.items()} - - return value - - -def transform_json_structure(data: Union[Dict[str, Any], List[Any], Any]) -> Any: - """ - Transforms a JSON structure by removing the 'sender' key and moving the 'value' to the top level. - - For example: - "key": [{"sender": null, "value": "some value"}] -> "key": "some value" - """ - if isinstance(data, dict): - # If this dict has both 'sender' and 'value', return just the value - if "value" in data and "sender" in data: - return data["value"] - # Otherwise, recursively process each key-value pair - return {k: transform_json_structure(v) for k, v in data.items()} - - elif isinstance(data, list): - # First, transform each item in the list. - transformed = [transform_json_structure(item) for item in data] - # If the original list has exactly one element and that element was a dict - # with 'sender' and 'value', then unwrap the list. - if len(data) == 1 and isinstance(data[0], dict) and "value" in data[0] and "sender" in data[0]: - return transformed[0] - return transformed - - else: - # For other data types, just return the value as is. - return data - - -def remove_unserializable_data(value: Any) -> Any: - """ - Removes certain unserializable data which is not needed for the pipeline state. - """ - - if isinstance(value, ChatMessage): # noqa: SIM102 - if "usage" in value.meta: # noqa: SIM102 - value.meta["usage"].pop("completion_tokens_details", None) - value.meta["usage"].pop("prompt_tokens_details", None) - - if isinstance(value, GeneratedAnswer): # noqa: SIM102 - if value.meta and "usage" in value.meta: # noqa: SIM102 - value.meta.pop("usage", None) - - # all_messages contains a list of unserialized ChatMessages - # TODO: we should find a better way to handle this - if value.meta and "all_messages" in value.meta: - value.meta.pop("all_messages", None) - - return value - - -def serialize_component_input(value: Any) -> Any: - """ - Serializes, so it can be saved to a file, any type of input to a pipeline component. - """ - value = remove_unserializable_data(value) - value = transform_json_structure(value) - if hasattr(value, "to_dict") and callable(getattr(value, "to_dict")): - serialized_value = value.to_dict() - serialized_value["_type"] = value.__class__.__name__ - return serialized_value - - # this is a hack to serialize inputs that don't have a to_dict - elif hasattr(value, "__dict__"): - return { - "_type": value.__class__.__name__, - "attributes": value.__dict__, - } - - # recursively serialize all inputs in a dict - elif isinstance(value, dict): - return {k: serialize_component_input(v) for k, v in value.items()} - - # recursively serialize all inputs in lists or tuples - elif isinstance(value, list): - return [serialize_component_input(item) for item in value] - - return value diff --git a/haystack_experimental/dataclasses/__init__.py b/haystack_experimental/dataclasses/__init__.py index 429ab99b..bc7446d4 100644 --- a/haystack_experimental/dataclasses/__init__.py +++ b/haystack_experimental/dataclasses/__init__.py @@ -10,9 +10,11 @@ _import_structure = { "chat_message": ["ChatMessage"], "image_content": ["ImageContent"], + "answer": ["GeneratedAnswer"], } if TYPE_CHECKING: + from .answer import GeneratedAnswer from .chat_message import ChatMessage from .image_content import ImageContent else: diff --git a/haystack_experimental/dataclasses/answer.py b/haystack_experimental/dataclasses/answer.py new file mode 100644 index 00000000..eaed7c49 --- /dev/null +++ b/haystack_experimental/dataclasses/answer.py @@ -0,0 +1,55 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import Any, Dict + +from haystack.core.serialization import default_from_dict, default_to_dict +from haystack.dataclasses import ChatMessage, Document +from haystack.dataclasses import GeneratedAnswer as HaystackGeneratedAnswer + + +@dataclass +class GeneratedAnswer(HaystackGeneratedAnswer): + def to_dict(self) -> Dict[str, Any]: + """ + Serialize the object to a dictionary. + + :returns: + Serialized dictionary representation of the object. + """ + documents = [doc.to_dict(flatten=False) for doc in self.documents] + + # Serialize ChatMessage objects to dicts + meta = self.meta + all_messages = meta.get("all_messages") + + # all_messages is either a list of ChatMessage objects or a list of strings + if all_messages and isinstance(all_messages[0], ChatMessage): + meta = {**meta, "all_messages": [msg.to_dict() for msg in all_messages]} + + return default_to_dict(self, data=self.data, query=self.query, documents=documents, meta=meta) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "GeneratedAnswer": + """ + Deserialize the object from a dictionary. + + :param data: + Dictionary representation of the object. + + :returns: + Deserialized object. + """ + init_params = data.get("init_parameters", {}) + + if (documents := init_params.get("documents")) is not None: + init_params["documents"] = [Document.from_dict(d) for d in documents] + + meta = init_params.get("meta", {}) + if (all_messages := meta.get("all_messages")) is not None and isinstance(all_messages[0], dict): + meta["all_messages"] = [ChatMessage.from_dict(m) for m in all_messages] + init_params["meta"] = meta + + return default_from_dict(cls, data) diff --git a/test/conftest.py b/test/conftest.py index ee8e7b0d..759101f9 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -13,6 +13,7 @@ from haystack.testing.test_utils import set_all_seeds from haystack import tracing, component +from haystack_experimental.core.pipeline.breakpoint import load_state set_all_seeds(0) @@ -75,7 +76,7 @@ def load_and_resume_pipeline_state(pipeline, output_directory: Path, component: f_name = Path(full_path).name if str(f_name).startswith(component): file_found = True - resume_state = pipeline.load_state(full_path) + resume_state = load_state(full_path) return pipeline.run(data=data, resume_state=resume_state) if not file_found: diff --git a/test/core/pipeline/test_breakpoint.py b/test/core/pipeline/test_breakpoint.py new file mode 100644 index 00000000..8da5e37b --- /dev/null +++ b/test/core/pipeline/test_breakpoint.py @@ -0,0 +1,291 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import json +import pytest + +from haystack.components.joiners import BranchJoiner +from haystack_experimental.core.pipeline.pipeline import Pipeline +from haystack_experimental.core.pipeline.breakpoint import _transform_json_structure, _serialize_component_input, _deserialize_component_input, load_state, _validate_breakpoint, _validate_resume_state + + +class TestBreakpoint: + """ + This class contains only unit tests for the breakpoint module. + """ + + def test_validate_breakpoint(self): + # simple pipeline + joiner_1 = BranchJoiner(type_=str) + joiner_2 = BranchJoiner(type_=str) + pipeline = Pipeline() + pipeline.add_component("comp1", joiner_1) + pipeline.add_component("comp2", joiner_2) + pipeline.connect("comp1", "comp2") + + # valid breakpoints + breakpoints = ("comp1", 0) + validated = _validate_breakpoint(breakpoints, pipeline.graph) + assert validated == ("comp1", 0) + + # should default to 0 + breakpoints = ("comp1", None) + validated = _validate_breakpoint(breakpoints, pipeline.graph) + assert validated == ("comp1", 0) + + # should remain as it is + breakpoints = ("comp1", -1) + validated = _validate_breakpoint(breakpoints, pipeline.graph) + assert validated == ("comp1", -1) + + # contains invalid components + breakpoints = ("comp3", 0) + with pytest.raises(ValueError, match="pipeline_breakpoint .* is not a registered component"): + _validate_breakpoint(breakpoints, pipeline.graph) + + # no breakpoints are defined + breakpoint = None + validated = _validate_breakpoint(breakpoint, pipeline.graph) + assert validated is None + + +def test_transform_json_structure_unwraps_sender_value(): + data = { + "key1": [{"sender": None, "value": "some value"}], + "key2": [{"sender": "comp1", "value": 42}], + "key3": "direct value" + } + + result = _transform_json_structure(data) + + assert result == { + "key1": "some value", + "key2": 42, + "key3": "direct value" + } + +def test_transform_json_structure_handles_nested_structures(): + data = { + "key1": [{"sender": None, "value": "value1"}], + "key2": { + "nested": [{"sender": "comp1", "value": "value2"}], + "direct": "value3" + }, + "key3": [ + [{"sender": None, "value": "value4"}], + [{"sender": "comp2", "value": "value5"}] + ] + } + + result = _transform_json_structure(data) + + assert result == { + "key1": "value1", + "key2": { + "nested": "value2", + "direct": "value3" + }, + "key3": [ + "value4", + "value5" + ] + } + +def test_serialize_component_input_handles_objects_with_to_dict(): + class TestObject: + def __init__(self, value): + self.value = value + + def to_dict(self): + return {"value": self.value} + + obj = TestObject("test") + result = _serialize_component_input(obj) + assert result == { + "_type": "TestObject", + "value": "test" + } + +def test_serialize_component_input_handles_objects_without_to_dict(): + class TestObject: + def __init__(self, value): + self.value = value + + obj = TestObject("test") + result = _serialize_component_input(obj) + assert result == { + "_type": "TestObject", + "attributes": {"value": "test"} + } + +def test_serialize_component_input_handles_nested_structures(): + class TestObject: + def __init__(self, value): + self.value = value + + def to_dict(self): + return {"value": self.value} + + obj = TestObject("test") + data = { + "key1": obj, + "key2": [obj, "string"], + "key3": {"nested": obj} + } + result = _serialize_component_input(data) + + assert result["key1"]["_type"] == "TestObject" + assert result["key2"][0]["_type"] == "TestObject" + assert result["key2"][1] == "string" + assert result["key3"]["nested"]["_type"] == "TestObject" + +def test_deserialize_component_input_handles_primitive_types(): + data = { + "string": "test", + "int": 42, + "float": 3.14, + "bool": True, + "none": None + } + result = _deserialize_component_input(data) + assert result == data + +def test_deserialize_component_input_handles_lists(): + data = { + "primitive_list": [1, 2, 3], + "mixed_list": [1, "string", True] + } + result = _deserialize_component_input(data) + assert result == data + +def test_deserialize_component_input_handles_dicts(): + data = { + "key1": "value1", + "key2": {"nested": "value2"} + } + result = _deserialize_component_input(data) + assert result == data + +def test_deserialize_component_input_handles_nested_lists(): + """Test that _deserialize_component_input handles nested lists""" + data = { + "nested_list": [[1, 2], [3, 4]], + "mixed_nested": [[1, "string"], [True, 3.14]] + } + + result = _deserialize_component_input(data) + + assert result == data + +def test_deserialize_component_input_handles_nested_dicts(): + """Test that _deserialize_component_input handles nested dictionaries""" + data = { + "key1": { + "nested1": "value1", + "nested2": { + "deep": "value2" + } + } + } + + result = _deserialize_component_input(data) + + assert result == data + +def test_deserialize_component_input_handles_empty_structures(): + """Test that _deserialize_component_input handles empty structures""" + data = { + "empty_list": [], + "empty_dict": {}, + "nested_empty": {"empty": []} + } + + result = _deserialize_component_input(data) + + assert result == data + +def test_validate_resume_state_validates_required_keys(): + state = { + "input_data": {}, + "pipeline_breakpoint": {"component": "comp1", "visits": 0} + # Missing pipeline_state + } + + with pytest.raises(ValueError, match="Invalid state file: missing required keys"): + _validate_resume_state(state) + + state = { + "input_data": {}, + "pipeline_breakpoint": {"component": "comp1", "visits": 0}, + "pipeline_state": { + "inputs": {}, + "component_visits": {} + # Missing ordered_component_names + } + } + + with pytest.raises(ValueError, match="Invalid pipeline_state: missing required keys"): + _validate_resume_state(state) + +def test_validate_resume_state_validates_component_consistency(): + state = { + "input_data": {}, + "pipeline_breakpoint": {"component": "comp1", "visits": 0}, + "pipeline_state": { + "inputs": {}, + "component_visits": {"comp1": 0, "comp2": 0}, + "ordered_component_names": ["comp1", "comp3"] # inconsistent with component_visits + } + } + + with pytest.raises(ValueError, match="Inconsistent state: components in pipeline_state"): + _validate_resume_state(state) + +def test_validate_resume_state_validates_valid_state(): + state = { + "input_data": {}, + "pipeline_breakpoint": {"component": "comp1", "visits": 0}, + "pipeline_state": { + "inputs": {}, + "component_visits": {"comp1": 0, "comp2": 0}, + "ordered_component_names": ["comp1", "comp2"] + } + } + + _validate_resume_state(state) # should not raise any exception + +def test_load_state_loads_valid_state(tmp_path): + state = { + "input_data": {}, + "pipeline_breakpoint": {"component": "comp1", "visits": 0}, + "pipeline_state": { + "inputs": {}, + "component_visits": {"comp1": 0, "comp2": 0}, + "ordered_component_names": ["comp1", "comp2"] + } + } + state_file = tmp_path / "state.json" + with open(state_file, "w") as f: + json.dump(state, f) + + loaded_state = load_state(state_file) + assert loaded_state == state + +def test_load_state_handles_invalid_state(tmp_path): + state = { + "input_data": {}, + "pipeline_breakpoint": {"component": "comp1", "visits": 0}, + "pipeline_state": { + "inputs": {}, + "component_visits": {"comp1": 0, "comp2": 0}, + "ordered_component_names": ["comp1", "comp3"] # inconsistent with component_visits + } + } + + state_file = tmp_path / "invalid_state.json" + with open(state_file, "w") as f: + json.dump(state, f) + + with pytest.raises(ValueError, match="Invalid pipeline state from"): + load_state(state_file) diff --git a/test/core/pipeline/test_pipeline.py b/test/core/pipeline/test_pipeline.py index da57274e..cf7559ea 100644 --- a/test/core/pipeline/test_pipeline.py +++ b/test/core/pipeline/test_pipeline.py @@ -11,9 +11,6 @@ from haystack.core.errors import PipelineRuntimeError from haystack_experimental.core.pipeline.pipeline import ( Pipeline, - transform_json_structure, - serialize_component_input, - deserialize_component_input ) @@ -113,278 +110,3 @@ def test_run(self): pp.connect("joiner_1", "joiner_2") _ = pp.run({"value": "test_value"}) - - def test_validate_breakpoints(self): - # simple pipeline - joiner_1 = BranchJoiner(type_=str) - joiner_2 = BranchJoiner(type_=str) - pipeline = Pipeline() - pipeline.add_component("comp1", joiner_1) - pipeline.add_component("comp2", joiner_2) - pipeline.connect("comp1", "comp2") - - # valid breakpoints - breakpoints = {("comp1", 0), ("comp2", 1)} - validated = pipeline._validate_breakpoints(breakpoints) - assert validated == {("comp1", 0), ("comp2", 1)} - - # should default to 0 - breakpoints = {("comp1", None), ("comp2", 1)} - validated = pipeline._validate_breakpoints(breakpoints) - assert validated == {("comp1", 0), ("comp2", 1)} - - # should remain as it is - breakpoints = {("comp1", -1)} - validated = pipeline._validate_breakpoints(breakpoints) - assert validated == {("comp1", -1)} - - # contains invalid components - breakpoints = {("comp1", 0), ("non_existent_component", 1)} - with pytest.raises(ValueError, match="Breakpoint .* is not a registered component"): - pipeline._validate_breakpoints(breakpoints) - - # no breakpoints are defined - breakpoints = set() - validated = pipeline._validate_breakpoints(breakpoints) - assert validated == set() - - -def test_transform_json_structure_unwraps_sender_value(): - data = { - "key1": [{"sender": None, "value": "some value"}], - "key2": [{"sender": "comp1", "value": 42}], - "key3": "direct value" - } - - result = transform_json_structure(data) - - assert result == { - "key1": "some value", - "key2": 42, - "key3": "direct value" - } - -def test_transform_json_structure_handles_nested_structures(): - data = { - "key1": [{"sender": None, "value": "value1"}], - "key2": { - "nested": [{"sender": "comp1", "value": "value2"}], - "direct": "value3" - }, - "key3": [ - [{"sender": None, "value": "value4"}], - [{"sender": "comp2", "value": "value5"}] - ] - } - - result = transform_json_structure(data) - - assert result == { - "key1": "value1", - "key2": { - "nested": "value2", - "direct": "value3" - }, - "key3": [ - "value4", - "value5" - ] - } - -def test_serialize_component_input_handles_objects_with_to_dict(): - class TestObject: - def __init__(self, value): - self.value = value - - def to_dict(self): - return {"value": self.value} - - obj = TestObject("test") - result = serialize_component_input(obj) - assert result == { - "_type": "TestObject", - "value": "test" - } - -def test_serialize_component_input_handles_objects_without_to_dict(): - class TestObject: - def __init__(self, value): - self.value = value - - obj = TestObject("test") - result = serialize_component_input(obj) - assert result == { - "_type": "TestObject", - "attributes": {"value": "test"} - } - -def test_serialize_component_input_handles_nested_structures(): - class TestObject: - def __init__(self, value): - self.value = value - - def to_dict(self): - return {"value": self.value} - - obj = TestObject("test") - data = { - "key1": obj, - "key2": [obj, "string"], - "key3": {"nested": obj} - } - result = serialize_component_input(data) - - assert result["key1"]["_type"] == "TestObject" - assert result["key2"][0]["_type"] == "TestObject" - assert result["key2"][1] == "string" - assert result["key3"]["nested"]["_type"] == "TestObject" - -def test_deserialize_component_input_handles_primitive_types(): - data = { - "string": "test", - "int": 42, - "float": 3.14, - "bool": True, - "none": None - } - result = deserialize_component_input(data) - assert result == data - -def test_deserialize_component_input_handles_lists(): - data = { - "primitive_list": [1, 2, 3], - "mixed_list": [1, "string", True] - } - result = deserialize_component_input(data) - assert result == data - -def test_deserialize_component_input_handles_dicts(): - data = { - "key1": "value1", - "key2": {"nested": "value2"} - } - result = deserialize_component_input(data) - assert result == data - -def test_deserialize_component_input_handles_nested_lists(): - """Test that _deserialize_component_input handles nested lists""" - data = { - "nested_list": [[1, 2], [3, 4]], - "mixed_nested": [[1, "string"], [True, 3.14]] - } - - result = deserialize_component_input(data) - - assert result == data - -def test_deserialize_component_input_handles_nested_dicts(): - """Test that _deserialize_component_input handles nested dictionaries""" - data = { - "key1": { - "nested1": "value1", - "nested2": { - "deep": "value2" - } - } - } - - result = deserialize_component_input(data) - - assert result == data - -def test_deserialize_component_input_handles_empty_structures(): - """Test that _deserialize_component_input handles empty structures""" - data = { - "empty_list": [], - "empty_dict": {}, - "nested_empty": {"empty": []} - } - - result = deserialize_component_input(data) - - assert result == data - -def test_validate_resume_state_validates_required_keys(): - state = { - "input_data": {}, - "breakpoint": {"component": "comp1", "visits": 0} - # Missing pipeline_state - } - - with pytest.raises(ValueError, match="Invalid state file: missing required keys"): - Pipeline._validate_resume_state(state) - - state = { - "input_data": {}, - "breakpoint": {"component": "comp1", "visits": 0}, - "pipeline_state": { - "inputs": {}, - "component_visits": {} - # Missing ordered_component_names - } - } - - with pytest.raises(ValueError, match="Invalid pipeline_state: missing required keys"): - Pipeline._validate_resume_state(state) - -def test_validate_resume_state_validates_component_consistency(): - state = { - "input_data": {}, - "breakpoint": {"component": "comp1", "visits": 0}, - "pipeline_state": { - "inputs": {}, - "component_visits": {"comp1": 0, "comp2": 0}, - "ordered_component_names": ["comp1", "comp3"] # inconsistent with component_visits - } - } - - with pytest.raises(ValueError, match="Inconsistent state: components in pipeline_state"): - Pipeline._validate_resume_state(state) - -def test_validate_resume_state_validates_valid_state(): - state = { - "input_data": {}, - "breakpoint": {"component": "comp1", "visits": 0}, - "pipeline_state": { - "inputs": {}, - "component_visits": {"comp1": 0, "comp2": 0}, - "ordered_component_names": ["comp1", "comp2"] - } - } - - Pipeline._validate_resume_state(state) # should not raise any exception - -def test_load_state_loads_valid_state(tmp_path): - state = { - "input_data": {}, - "breakpoint": {"component": "comp1", "visits": 0}, - "pipeline_state": { - "inputs": {}, - "component_visits": {"comp1": 0, "comp2": 0}, - "ordered_component_names": ["comp1", "comp2"] - } - } - state_file = tmp_path / "state.json" - with open(state_file, "w") as f: - json.dump(state, f) - - loaded_state = Pipeline.load_state(state_file) - assert loaded_state == state - -def test_load_state_handles_invalid_state(tmp_path): - state = { - "input_data": {}, - "breakpoint": {"component": "comp1", "visits": 0}, - "pipeline_state": { - "inputs": {}, - "component_visits": {"comp1": 0, "comp2": 0}, - "ordered_component_names": ["comp1", "comp3"] # inconsistent with component_visits - } - } - - state_file = tmp_path / "invalid_state.json" - with open(state_file, "w") as f: - json.dump(state, f) - - with pytest.raises(ValueError, match="Invalid pipeline state from"): - Pipeline.load_state(state_file) diff --git a/test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py b/test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py index a914cdf8..a5a8485e 100644 --- a/test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py +++ b/test/core/pipeline/test_pipeline_breakpoints_answer_joiner.py @@ -2,7 +2,7 @@ import pytest -from haystack.components.builders.answer_builder import AnswerBuilder +from haystack_experimental.components.builders.answer_builder import AnswerBuilder from haystack.components.generators.chat import OpenAIChatGenerator from haystack.components.joiners import AnswerJoiner from haystack.core.pipeline import Pipeline @@ -66,7 +66,7 @@ def answer_join_pipeline(self, mock_openai_chat_generator): Creates a pipeline with mocked OpenAI components. """ # Create the pipeline with mocked components - pipeline = Pipeline() + pipeline = Pipeline(connection_type_validation=False) pipeline.add_component("gpt-4o", mock_openai_chat_generator("gpt-4o")) pipeline.add_component("gpt-3", mock_openai_chat_generator("gpt-3.5-turbo")) pipeline.add_component("answer_builder_a", AnswerBuilder()) @@ -109,7 +109,7 @@ def test_pipeline_breakpoints_answer_joiner(self, answer_join_pipeline, output_d } try: - _ = answer_join_pipeline.run(data, breakpoints={(component, 0)}, debug_path=str(output_directory)) + _ = answer_join_pipeline.run(data, pipeline_breakpoint=(component, 0), debug_path=str(output_directory)) except PipelineBreakpointException as e: pass diff --git a/test/core/pipeline/test_pipeline_breakpoints_branch_joiner.py b/test/core/pipeline/test_pipeline_breakpoints_branch_joiner.py index 84cd8b15..ad2d0ded 100644 --- a/test/core/pipeline/test_pipeline_breakpoints_branch_joiner.py +++ b/test/core/pipeline/test_pipeline_breakpoints_branch_joiner.py @@ -13,6 +13,7 @@ from haystack.utils.auth import Secret from haystack_experimental.core.errors import PipelineBreakpointException from haystack_experimental.core.pipeline.pipeline import Pipeline +from haystack_experimental.core.pipeline.breakpoint import load_state from unittest.mock import patch class TestPipelineBreakpoints: @@ -106,7 +107,7 @@ def test_pipeline_breakpoints_branch_joiner(self, branch_joiner_pipeline, output } try: - _ = branch_joiner_pipeline.run(data, breakpoints={(component, 0)}, debug_path=str(output_directory)) + _ = branch_joiner_pipeline.run(data, pipeline_breakpoint=(component, 0), debug_path=str(output_directory)) except PipelineBreakpointException as e: pass @@ -116,7 +117,7 @@ def test_pipeline_breakpoints_branch_joiner(self, branch_joiner_pipeline, output f_name = Path(full_path).name if str(f_name).startswith(component): file_found = True - resume_state = Pipeline.load_state(full_path) + resume_state = load_state(full_path) result = branch_joiner_pipeline.run(data, resume_state=resume_state) assert result['validator'] break diff --git a/test/core/pipeline/test_pipeline_breakpoints_list_joiner.py b/test/core/pipeline/test_pipeline_breakpoints_list_joiner.py index c44df13c..cade6bf6 100644 --- a/test/core/pipeline/test_pipeline_breakpoints_list_joiner.py +++ b/test/core/pipeline/test_pipeline_breakpoints_list_joiner.py @@ -120,7 +120,7 @@ def test_list_joiner_pipeline(self, list_joiner_pipeline, output_directory, comp } try: - _ = list_joiner_pipeline.run(data, breakpoints={(component, 0)}, debug_path=str(output_directory)) + _ = list_joiner_pipeline.run(data, pipeline_breakpoint=(component, 0), debug_path=str(output_directory)) except PipelineBreakpointException as e: pass diff --git a/test/core/pipeline/test_pipeline_breakpoints_loops.py b/test/core/pipeline/test_pipeline_breakpoints_loops.py index fce2e735..12b1197d 100644 --- a/test/core/pipeline/test_pipeline_breakpoints_loops.py +++ b/test/core/pipeline/test_pipeline_breakpoints_loops.py @@ -16,6 +16,7 @@ from haystack.utils.auth import Secret from haystack_experimental.core.errors import PipelineBreakpointException from haystack_experimental.core.pipeline.pipeline import Pipeline +from haystack_experimental.core.pipeline.breakpoint import load_state from unittest.mock import patch # Define the component input parameters @@ -200,7 +201,7 @@ def test_pipeline_breakpoints_validation_loop(self, validation_loop_pipeline, ou } try: - _ = validation_loop_pipeline.run(data, breakpoints={(component, 0)}, debug_path=str(output_directory)) + _ = validation_loop_pipeline.run(data, pipeline_breakpoint=(component, 0), debug_path=str(output_directory)) except PipelineBreakpointException: pass @@ -210,7 +211,7 @@ def test_pipeline_breakpoints_validation_loop(self, validation_loop_pipeline, ou f_name = Path(full_path).name if str(f_name).startswith(component): file_found = True - resume_state = Pipeline.load_state(full_path) + resume_state = load_state(full_path) result = validation_loop_pipeline.run(data={}, resume_state=resume_state) # Verify the result contains valid output if "output_validator" in result and "valid_replies" in result["output_validator"]: diff --git a/test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py b/test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py index 175e1749..c3996822 100644 --- a/test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py +++ b/test/core/pipeline/test_pipeline_breakpoints_rag_hybrid.py @@ -2,7 +2,7 @@ import pytest -from haystack.components.builders.answer_builder import AnswerBuilder +from haystack_experimental.components.builders.answer_builder import AnswerBuilder from haystack.components.builders.prompt_builder import PromptBuilder from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder from haystack.components.generators import OpenAIGenerator @@ -208,7 +208,7 @@ def hybrid_rag_pipeline(self, document_store, mock_transformers_similarity_ranke \nQuestion: {{question}} \nAnswer: """ - pipeline = Pipeline() + pipeline = Pipeline(connection_type_validation=False) pipeline.add_component(instance=InMemoryBM25Retriever(document_store=document_store), name="bm25_retriever") # Use the mocked embedder instead of creating a new one @@ -274,7 +274,7 @@ def test_pipeline_breakpoints_hybrid_rag( } try: - _ = hybrid_rag_pipeline.run(data, breakpoints={(component, 0)}, debug_path=str(output_directory)) + _ = hybrid_rag_pipeline.run(data, pipeline_breakpoint=(component, 0), debug_path=str(output_directory)) except PipelineBreakpointException as e: pass diff --git a/test/core/pipeline/test_pipeline_breakpoints_string_joiner.py b/test/core/pipeline/test_pipeline_breakpoints_string_joiner.py index 25957718..5420d53c 100644 --- a/test/core/pipeline/test_pipeline_breakpoints_string_joiner.py +++ b/test/core/pipeline/test_pipeline_breakpoints_string_joiner.py @@ -49,7 +49,7 @@ def test_string_joiner_pipeline(self, string_joiner_pipeline, output_directory, data = {"prompt_builder_1": {"query": string_1}, "prompt_builder_2": {"query": string_2}} try: - _ = string_joiner_pipeline.run(data, breakpoints={(component, 0)}, debug_path=str(output_directory)) + _ = string_joiner_pipeline.run(data, pipeline_breakpoint=(component, 0), debug_path=str(output_directory)) except PipelineBreakpointException as e: pass