Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 8 additions & 16 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,18 @@ jobs:
linting:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: actions/checkout@v6

- name: Get changed files
id: files
uses: tj-actions/changed-files@v47
with:
files_yaml: |
python:
- '**/*.py'
- '!test/**'
pyproject:
- 'pyproject.toml'

files: |
**/*.py
pyproject.toml
files_ignore: |
test/**
.github/**
- uses: actions/setup-python@v6
with:
python-version: "${{ env.PYTHON_VERSION }}"
Expand All @@ -65,15 +64,8 @@ jobs:
- name: Ruff - check format and linting
run: hatch run fmt-check

- name: Pylint
# Running pylint on pyproject.toml causes errors, so we only run it on python files.
if: steps.files.outputs.python_any_changed == 'true'
run: |
hatch run test:lint ${{ steps.files.outputs.python_all_changed_files }}


- name: Typing
if: steps.files.outputs.python_any_changed == 'true' || steps.files.outputs.pyproject_any_changed == 'true'
if: steps.files.outputs.any_changed == 'true'
run: |
mkdir .mypy_cache
hatch run test:types
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ repos:
args: [--markdown-linebreak-ext=md]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.0
rev: v0.15.2
hooks:
- id: ruff-check
args: [ --fix ]
Expand Down
12 changes: 10 additions & 2 deletions haystack_experimental/chat_message_stores/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@
#
# SPDX-License-Identifier: Apache-2.0

from haystack_experimental.chat_message_stores.in_memory import InMemoryChatMessageStore
import sys
from typing import TYPE_CHECKING

__all__ = ["InMemoryChatMessageStore"]
from lazy_imports import LazyImporter

_import_structure = {"in_memory": ["InMemoryChatMessageStore"]}

if TYPE_CHECKING:
from .in_memory import InMemoryChatMessageStore as InMemoryChatMessageStore
else:
sys.modules[__name__] = LazyImporter(name=__name__, module_file=__file__, import_structure=_import_structure)
3 changes: 2 additions & 1 deletion haystack_experimental/chat_message_stores/in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
#
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Iterable
from dataclasses import replace
from typing import Any, Iterable
from typing import Any

from haystack import default_from_dict, default_to_dict
from haystack.dataclasses import ChatMessage, ChatRole
Expand Down
3 changes: 0 additions & 3 deletions haystack_experimental/chat_message_stores/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@

from haystack.dataclasses import ChatMessage

# Ellipsis are needed for the type checker, it's safe to disable module-wide
# pylint: disable=unnecessary-ellipsis


class ChatMessageStore(Protocol):
"""
Expand Down
4 changes: 1 addition & 3 deletions haystack_experimental/components/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#
# SPDX-License-Identifier: Apache-2.0

# pylint: disable=wrong-import-order,wrong-import-position,ungrouped-imports
# ruff: noqa: I001

import inspect
Expand Down Expand Up @@ -240,8 +239,7 @@ def _initialize_fresh_execution(
retriever_kwargs = _select_kwargs(self._chat_message_retriever, chat_message_store_kwargs or {})
if "chat_history_id" in retriever_kwargs:
updated_messages = self._chat_message_retriever.run(
current_messages=exe_context.state.get("messages", []),
**retriever_kwargs,
current_messages=exe_context.state.get("messages", []), **retriever_kwargs
)["messages"]
# We replace the messages in state with the updated messages including chat history
exe_context.state.set("messages", updated_messages, handler_override=replace_values)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,9 @@

from lazy_imports import LazyImporter

_import_structure = {
"dataclasses": ["ToolExecutionDecision"],
"errors": ["HITLBreakpointException"],
"strategies": ["BreakpointConfirmationStrategy"],
}
_import_structure = {"errors": ["HITLBreakpointException"], "strategies": ["BreakpointConfirmationStrategy"]}

if TYPE_CHECKING:
from .dataclasses import ToolExecutionDecision as ToolExecutionDecision
from .errors import HITLBreakpointException as HITLBreakpointException
from .strategies import BreakpointConfirmationStrategy as BreakpointConfirmationStrategy

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def get_tool_calls_and_descriptions_from_snapshot(
"""
break_point = agent_snapshot.break_point.break_point
if not isinstance(break_point, ToolBreakpoint):
raise ValueError("The provided AgentSnapshot does not contain a ToolBreakpoint.")
raise TypeError("The provided AgentSnapshot does not contain a ToolBreakpoint.")

tool_caused_break_point = break_point.tool_name

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ def run(
self,
*,
tool_name: str,
tool_description: str,
tool_params: dict[str, Any],
tool_description: str, # noqa: ARG002
tool_params: dict[str, Any], # noqa: ARG002
tool_call_id: str | None = None,
confirmation_strategy_context: dict[str, Any] | None = None,
confirmation_strategy_context: dict[str, Any] | None = None, # noqa: ARG002
) -> ToolExecutionDecision:
"""
Run the breakpoint confirmation strategy for a given tool and its parameters.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@

from haystack import Document

# See https://github.com/pylint-dev/pylint/issues/9319.
# pylint: disable=unnecessary-ellipsis


class DocumentEmbedder(Protocol):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@

from lazy_imports import LazyImporter

_import_structure = {
"openai": ["OpenAIChatGenerator"],
}
_import_structure = {"openai": ["OpenAIChatGenerator"]}

if TYPE_CHECKING:
from .openai import OpenAIChatGenerator as OpenAIChatGenerator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@

from lazy_imports import LazyImporter

_import_structure = {
"md_header_level_inferrer": ["MarkdownHeaderLevelInferrer"],
}
_import_structure = {"md_header_level_inferrer": ["MarkdownHeaderLevelInferrer"]}

if TYPE_CHECKING:
from .md_header_level_inferrer import MarkdownHeaderLevelInferrer as MarkdownHeaderLevelInferrer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def run(self, documents: list[Document]) -> dict:
logger.warning("No documents provided to process")
return {"documents": []}

logger.debug(f"Inferring and rewriting header levels for {len(documents)} documents")
logger.debug(
"Inferring and rewriting header levels for {num_documents} documents", num_documents=len(documents)
)
processed_docs = [self._process_document(doc) for doc in documents]
return {"documents": processed_docs}

Expand All @@ -69,16 +71,22 @@ def _process_document(self, doc: Document) -> Document:
Document object with rewritten header levels.
"""
if doc.content is None:
logger.warning(f"Document {getattr(doc, 'id', '')} content is None; skipping header level inference.")
logger.warning(
"Document {doc_id} content is None; skipping header level inference.", doc_id=getattr(doc, "id", "")
)
return doc

matches = list(re.finditer(self._header_pattern, doc.content))
if not matches:
logger.info(f"No headers found in document {doc.id}; skipping header level inference.")
logger.info("No headers found in document {doc_id}; skipping header level inference.", doc_id=doc.id)
return doc

modified_text = MarkdownHeaderLevelInferrer._rewrite_headers(doc.content, matches)
logger.info(f"Rewrote {len(matches)} headers with inferred levels in document{doc.id}.")
logger.info(
"Rewrote {num_headers} headers with inferred levels in document{doc_id}.",
num_headers=len(matches),
doc_id=doc.id,
)
return MarkdownHeaderLevelInferrer._build_final_document(doc, modified_text)

@staticmethod
Expand All @@ -99,7 +107,7 @@ def _rewrite_headers(content: str, matches: list[re.Match]) -> str:

# Skip empty headers
if not header_text:
logger.warning(f"Skipping empty header at position {match.start()}")
logger.warning("Skipping empty header at position {start}", start=match.start())
continue

has_content = MarkdownHeaderLevelInferrer._has_content_between_headers(content, matches, i)
Expand Down
12 changes: 10 additions & 2 deletions haystack_experimental/components/retrievers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@
#
# SPDX-License-Identifier: Apache-2.0

from haystack_experimental.components.retrievers.chat_message_retriever import ChatMessageRetriever
import sys
from typing import TYPE_CHECKING

_all_ = ["ChatMessageRetriever"]
from lazy_imports import LazyImporter

_import_structure = {"chat_message_retriever": ["ChatMessageRetriever"]}

if TYPE_CHECKING:
from .chat_message_retriever import ChatMessageRetriever as ChatMessageRetriever
else:
sys.modules[__name__] = LazyImporter(name=__name__, module_file=__file__, import_structure=_import_structure)
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,7 @@ def from_dict(cls, data: dict[str, Any]) -> "ChatMessageRetriever":

@component.output_types(messages=list[ChatMessage])
def run(
self,
chat_history_id: str,
*,
last_k: int | None = None,
current_messages: list[ChatMessage] | None = None,
self, chat_history_id: str, *, last_k: int | None = None, current_messages: list[ChatMessage] | None = None
) -> dict[str, list[ChatMessage]]:
"""
Run the ChatMessageRetriever
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@

from typing import Any, Protocol

# Ellipsis are needed to define the Protocol but pylint complains. See https://github.com/pylint-dev/pylint/issues/9319.
# pylint: disable=unnecessary-ellipsis


class TextRetriever(Protocol):
"""
Expand Down
13 changes: 11 additions & 2 deletions haystack_experimental/components/summarizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@
#
# SPDX-License-Identifier: Apache-2.0

from haystack_experimental.components.summarizers.llm_summarizer import LLMSummarizer
import sys
from typing import TYPE_CHECKING

_all_ = ["Summarizer"]
from lazy_imports import LazyImporter

_import_structure = {"llm_summarizer": ["LLMSummarizer"]}

if TYPE_CHECKING:
from .llm_summarizer import LLMSummarizer as LLMSummarizer

else:
sys.modules[__name__] = LazyImporter(name=__name__, module_file=__file__, import_structure=_import_structure)
14 changes: 3 additions & 11 deletions haystack_experimental/components/summarizers/llm_summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class LLMSummarizer:
```
"""

def __init__( # pylint: disable=too-many-positional-arguments
def __init__(
self,
chat_generator: ChatGenerator,
system_prompt: str | None = "Rewrite this text in summarized form.",
Expand Down Expand Up @@ -207,9 +207,7 @@ def _prepare_text_chunks(self, text, detail, minimum_chunk_size, chunk_delimiter

temp_doc = Document(content=text)
result = self._document_splitter.run(documents=[temp_doc])
text_chunks = [doc.content for doc in result["documents"]]

return text_chunks
return [doc.content for doc in result["documents"]]

def _process_chunks(self, text_chunks, summarize_recursively):
"""
Expand All @@ -236,13 +234,7 @@ def _process_chunks(self, text_chunks, summarize_recursively):

return accumulated_summaries

def summarize(
self,
text: str,
detail: float,
minimum_chunk_size: int,
summarize_recursively: bool = False,
) -> str:
def summarize(self, text: str, detail: float, minimum_chunk_size: int, summarize_recursively: bool = False) -> str:
"""
Summarizes text by splitting it into optimally-sized chunks and processing each with an LLM.

Expand Down
12 changes: 10 additions & 2 deletions haystack_experimental/components/writers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@
#
# SPDX-License-Identifier: Apache-2.0

from haystack_experimental.components.writers.chat_message_writer import ChatMessageWriter
import sys
from typing import TYPE_CHECKING

__all__ = ["ChatMessageWriter"]
from lazy_imports import LazyImporter

_import_structure = {"chat_message_writer": ["ChatMessageWriter"]}

if TYPE_CHECKING:
from .chat_message_writer import ChatMessageWriter as ChatMessageWriter
else:
sys.modules[__name__] = LazyImporter(name=__name__, module_file=__file__, import_structure=_import_structure)
11 changes: 4 additions & 7 deletions haystack_experimental/core/pipeline/breakpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,18 @@

from dataclasses import replace
from datetime import datetime
from typing import TYPE_CHECKING, Any
from typing import Any

from haystack import logging
from haystack.components.agents.agent import _ExecutionContext
from haystack.core.pipeline.utils import _deepcopy_with_exceptions
from haystack.dataclasses.breakpoints import AgentBreakpoint, PipelineSnapshot, PipelineState, ToolBreakpoint
from haystack.human_in_the_loop import ToolExecutionDecision
from haystack.utils.base_serialization import _serialize_value_with_schema
from haystack.utils.misc import _get_output_dir

from haystack_experimental.dataclasses.breakpoints import AgentSnapshot

if TYPE_CHECKING:
from haystack_experimental.components.agents.agent import _ExecutionContext
from haystack_experimental.components.agents.human_in_the_loop import ToolExecutionDecision


logger = logging.getLogger(__name__)


Expand All @@ -27,7 +24,7 @@ def _create_agent_snapshot(
component_visits: dict[str, int],
agent_breakpoint: AgentBreakpoint,
component_inputs: dict[str, Any],
tool_execution_decisions: list["ToolExecutionDecision"] | None = None,
tool_execution_decisions: list[ToolExecutionDecision] | None = None,
) -> AgentSnapshot:
"""
Create a snapshot of the agent's state.
Expand Down
Loading