Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions haystack_experimental/components/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def _initialize_from_snapshot( # type: ignore[override]
tool_execution_decisions=snapshot.tool_execution_decisions,
)

def run( # type: ignore[override] # noqa: PLR0915 PLR0912
def run( # type: ignore[override] # noqa: PLR0915 PLR0912 C901
self,
messages: list[ChatMessage],
streaming_callback: StreamingCallbackT | None = None,
Expand Down Expand Up @@ -381,7 +381,11 @@ def run( # type: ignore[override] # noqa: PLR0915 PLR0912
}
# TODO Probably good to add a warning in runtime checks that BreakpointConfirmationStrategy will take
# precedence over passing a ToolBreakpoint
self._runtime_checks(break_point)
# Support both old signature (break_point) and new signature (break_point, tools)
_runtime_checks_kwargs: dict[str, Any] = {"break_point": break_point}
if "tools" in inspect.signature(HaystackAgent._runtime_checks).parameters:
_runtime_checks_kwargs["tools"] = tools
self._runtime_checks(**_runtime_checks_kwargs)

if snapshot:
exe_context = self._initialize_from_snapshot(
Expand Down Expand Up @@ -573,7 +577,7 @@ def run( # type: ignore[override] # noqa: PLR0915 PLR0912

return result

async def run_async( # type: ignore[override] # noqa: PLR0915
async def run_async( # type: ignore[override] # noqa: PLR0915 PLR0912
self,
messages: list[ChatMessage],
streaming_callback: StreamingCallbackT | None = None,
Expand Down Expand Up @@ -647,7 +651,10 @@ async def run_async( # type: ignore[override] # noqa: PLR0915
"snapshot": snapshot,
**kwargs,
}
self._runtime_checks(break_point)
_runtime_checks_kwargs: dict[str, Any] = {"break_point": break_point}
if "tools" in inspect.signature(HaystackAgent._runtime_checks).parameters:
_runtime_checks_kwargs["tools"] = tools
self._runtime_checks(**_runtime_checks_kwargs)

if snapshot:
exe_context = self._initialize_from_snapshot(
Expand Down
10 changes: 7 additions & 3 deletions test/memory_stores/test_mem0_memory_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ def test_add_memories(self, sample_messages, memory_store):
"""Test adding memories successfully."""
store, user_id = memory_store
result = store.add_memories(messages=sample_messages, user_id=user_id)
assert len(result) == 2
# with infer=True (default), two messages are converted to a single memory
assert len(result) == 1

@pytest.mark.skipif(
not os.environ.get("MEM0_API_KEY", None),
Expand Down Expand Up @@ -124,8 +125,11 @@ def test_add_memories_with_metadata(self, memory_store):
reason="Export an env var called MEM0_API_KEY containing the Mem0 API key to run this test.",
)
@pytest.mark.integration
def test_search_memories(self, sample_messages):
"""Test searching memories on previously added memories because the mem0 takes time to index the memory"""
def test_search_memories(self):
"""
Test searching memories on previously added memories because the mem0 takes time to index the memory.
This test points to an existing memory index in mem0 associated to the account we use in CI.
"""
memory_store = Mem0MemoryStore()
# search without query
result = memory_store.search_memories(user_id="haystack_simple_memories")
Expand Down