diff --git a/haystack_experimental/components/agents/agent.py b/haystack_experimental/components/agents/agent.py index 73f265e7..680936a1 100644 --- a/haystack_experimental/components/agents/agent.py +++ b/haystack_experimental/components/agents/agent.py @@ -308,7 +308,7 @@ def _initialize_from_snapshot( # type: ignore[override] tool_execution_decisions=snapshot.tool_execution_decisions, ) - def run( # type: ignore[override] # noqa: PLR0915 PLR0912 + def run( # type: ignore[override] # noqa: PLR0915 PLR0912 C901 self, messages: list[ChatMessage], streaming_callback: StreamingCallbackT | None = None, @@ -381,7 +381,11 @@ def run( # type: ignore[override] # noqa: PLR0915 PLR0912 } # TODO Probably good to add a warning in runtime checks that BreakpointConfirmationStrategy will take # precedence over passing a ToolBreakpoint - self._runtime_checks(break_point) + # Support both old signature (break_point) and new signature (break_point, tools) + _runtime_checks_kwargs: dict[str, Any] = {"break_point": break_point} + if "tools" in inspect.signature(HaystackAgent._runtime_checks).parameters: + _runtime_checks_kwargs["tools"] = tools + self._runtime_checks(**_runtime_checks_kwargs) if snapshot: exe_context = self._initialize_from_snapshot( @@ -573,7 +577,7 @@ def run( # type: ignore[override] # noqa: PLR0915 PLR0912 return result - async def run_async( # type: ignore[override] # noqa: PLR0915 + async def run_async( # type: ignore[override] # noqa: PLR0915 PLR0912 self, messages: list[ChatMessage], streaming_callback: StreamingCallbackT | None = None, @@ -647,7 +651,10 @@ async def run_async( # type: ignore[override] # noqa: PLR0915 "snapshot": snapshot, **kwargs, } - self._runtime_checks(break_point) + _runtime_checks_kwargs: dict[str, Any] = {"break_point": break_point} + if "tools" in inspect.signature(HaystackAgent._runtime_checks).parameters: + _runtime_checks_kwargs["tools"] = tools + self._runtime_checks(**_runtime_checks_kwargs) if snapshot: exe_context = self._initialize_from_snapshot( diff --git a/test/memory_stores/test_mem0_memory_store.py b/test/memory_stores/test_mem0_memory_store.py index df2dd0cc..6b762910 100644 --- a/test/memory_stores/test_mem0_memory_store.py +++ b/test/memory_stores/test_mem0_memory_store.py @@ -92,7 +92,8 @@ def test_add_memories(self, sample_messages, memory_store): """Test adding memories successfully.""" store, user_id = memory_store result = store.add_memories(messages=sample_messages, user_id=user_id) - assert len(result) == 2 + # with infer=True (default), two messages are converted to a single memory + assert len(result) == 1 @pytest.mark.skipif( not os.environ.get("MEM0_API_KEY", None), @@ -124,8 +125,11 @@ def test_add_memories_with_metadata(self, memory_store): reason="Export an env var called MEM0_API_KEY containing the Mem0 API key to run this test.", ) @pytest.mark.integration - def test_search_memories(self, sample_messages): - """Test searching memories on previously added memories because the mem0 takes time to index the memory""" + def test_search_memories(self): + """ + Test searching memories on previously added memories because the mem0 takes time to index the memory. + This test points to an existing memory index in mem0 associated to the account we use in CI. + """ memory_store = Mem0MemoryStore() # search without query result = memory_store.search_memories(user_id="haystack_simple_memories")