Skip to content

Commit 2daf520

Browse files
authored
fix: fix Agent implementation and Mem0MemoryStore test (#450)
1 parent 5ef578a commit 2daf520

File tree

2 files changed

+18
-7
lines changed

2 files changed

+18
-7
lines changed

haystack_experimental/components/agents/agent.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def _initialize_from_snapshot( # type: ignore[override]
308308
tool_execution_decisions=snapshot.tool_execution_decisions,
309309
)
310310

311-
def run( # type: ignore[override] # noqa: PLR0915 PLR0912
311+
def run( # type: ignore[override] # noqa: PLR0915 PLR0912 C901
312312
self,
313313
messages: list[ChatMessage],
314314
streaming_callback: StreamingCallbackT | None = None,
@@ -381,7 +381,11 @@ def run( # type: ignore[override] # noqa: PLR0915 PLR0912
381381
}
382382
# TODO Probably good to add a warning in runtime checks that BreakpointConfirmationStrategy will take
383383
# precedence over passing a ToolBreakpoint
384-
self._runtime_checks(break_point)
384+
# Support both old signature (break_point) and new signature (break_point, tools)
385+
_runtime_checks_kwargs: dict[str, Any] = {"break_point": break_point}
386+
if "tools" in inspect.signature(HaystackAgent._runtime_checks).parameters:
387+
_runtime_checks_kwargs["tools"] = tools
388+
self._runtime_checks(**_runtime_checks_kwargs)
385389

386390
if snapshot:
387391
exe_context = self._initialize_from_snapshot(
@@ -573,7 +577,7 @@ def run( # type: ignore[override] # noqa: PLR0915 PLR0912
573577

574578
return result
575579

576-
async def run_async( # type: ignore[override] # noqa: PLR0915
580+
async def run_async( # type: ignore[override] # noqa: PLR0915 PLR0912
577581
self,
578582
messages: list[ChatMessage],
579583
streaming_callback: StreamingCallbackT | None = None,
@@ -647,7 +651,10 @@ async def run_async( # type: ignore[override] # noqa: PLR0915
647651
"snapshot": snapshot,
648652
**kwargs,
649653
}
650-
self._runtime_checks(break_point)
654+
_runtime_checks_kwargs: dict[str, Any] = {"break_point": break_point}
655+
if "tools" in inspect.signature(HaystackAgent._runtime_checks).parameters:
656+
_runtime_checks_kwargs["tools"] = tools
657+
self._runtime_checks(**_runtime_checks_kwargs)
651658

652659
if snapshot:
653660
exe_context = self._initialize_from_snapshot(

test/memory_stores/test_mem0_memory_store.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ def test_add_memories(self, sample_messages, memory_store):
9292
"""Test adding memories successfully."""
9393
store, user_id = memory_store
9494
result = store.add_memories(messages=sample_messages, user_id=user_id)
95-
assert len(result) == 2
95+
# with infer=True (default), two messages are converted to a single memory
96+
assert len(result) == 1
9697

9798
@pytest.mark.skipif(
9899
not os.environ.get("MEM0_API_KEY", None),
@@ -124,8 +125,11 @@ def test_add_memories_with_metadata(self, memory_store):
124125
reason="Export an env var called MEM0_API_KEY containing the Mem0 API key to run this test.",
125126
)
126127
@pytest.mark.integration
127-
def test_search_memories(self, sample_messages):
128-
"""Test searching memories on previously added memories because the mem0 takes time to index the memory"""
128+
def test_search_memories(self):
129+
"""
130+
Test searching memories on previously added memories because the mem0 takes time to index the memory.
131+
This test points to an existing memory index in mem0 associated to the account we use in CI.
132+
"""
129133
memory_store = Mem0MemoryStore()
130134
# search without query
131135
result = memory_store.search_memories(user_id="haystack_simple_memories")

0 commit comments

Comments
 (0)