Skip to content

Commit 68b17d8

Browse files
anakin87mpangrazzi
authored andcommitted
fix: make Agent run_async work with async streaming_callback (#9824)
* fix Agent streaming_callback requires_async * add tests * fix * relnote
1 parent 134998a commit 68b17d8

3 files changed

Lines changed: 47 additions & 8 deletions

File tree

haystack/components/agents/agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -663,13 +663,13 @@ async def run_async(
663663

664664
if snapshot:
665665
exe_context = self._initialize_from_snapshot(
666-
snapshot=snapshot, streaming_callback=streaming_callback, requires_async=False, tools=tools
666+
snapshot=snapshot, streaming_callback=streaming_callback, requires_async=True, tools=tools
667667
)
668668
else:
669669
exe_context = self._initialize_fresh_execution(
670670
messages=messages,
671671
streaming_callback=streaming_callback,
672-
requires_async=False,
672+
requires_async=True,
673673
system_prompt=system_prompt,
674674
tools=tools,
675675
**kwargs,
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
fixes:
3+
- |
4+
Fix Agent `run_async` method to correctly handle async streaming callbacks.
5+
This previously triggered errors due to a bug.

test/components/agents/test_agent.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,13 @@
2929
from haystack.utils import Secret, serialize_callable
3030

3131

32-
def streaming_callback_for_serde(chunk: StreamingChunk):
32+
def sync_streaming_callback(chunk: StreamingChunk) -> None:
33+
"""A synchronous streaming callback."""
34+
pass
35+
36+
37+
async def async_streaming_callback(chunk: StreamingChunk) -> None:
38+
"""An asynchronous streaming callback."""
3339
pass
3440

3541

@@ -501,18 +507,16 @@ def test_serde_with_streaming_callback(self, weather_tool, component_tool, monke
501507
monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
502508
generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY"))
503509
agent = Agent(
504-
chat_generator=generator,
505-
tools=[weather_tool, component_tool],
506-
streaming_callback=streaming_callback_for_serde,
510+
chat_generator=generator, tools=[weather_tool, component_tool], streaming_callback=sync_streaming_callback
507511
)
508512

509513
serialized_agent = agent.to_dict()
510514

511515
init_parameters = serialized_agent["init_parameters"]
512-
assert init_parameters["streaming_callback"] == "test_agent.streaming_callback_for_serde"
516+
assert init_parameters["streaming_callback"] == "test_agent.sync_streaming_callback"
513517

514518
deserialized_agent = Agent.from_dict(serialized_agent)
515-
assert deserialized_agent.streaming_callback is streaming_callback_for_serde
519+
assert deserialized_agent.streaming_callback is sync_streaming_callback
516520

517521
def test_exit_conditions_validation(self, weather_tool, component_tool, monkeypatch):
518522
monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key")
@@ -928,6 +932,36 @@ def streaming_callback(chunk: StreamingChunk) -> None:
928932
assert result["last_message"] is not None
929933
assert streaming_callback_called
930934

935+
@pytest.mark.asyncio
936+
async def test_run_async_with_async_streaming_callback(self, weather_tool):
937+
chat_generator = MockChatGenerator()
938+
agent = Agent(chat_generator=chat_generator, tools=[weather_tool], streaming_callback=async_streaming_callback)
939+
agent.warm_up()
940+
941+
# This should not raise any exception
942+
result = await agent.run_async([ChatMessage.from_user("Hello")])
943+
944+
assert "messages" in result
945+
assert len(result["messages"]) == 2
946+
assert result["messages"][1].text == "Hello from run_async"
947+
948+
def test_run_with_async_streaming_callback_fails(self, weather_tool):
949+
chat_generator = MockChatGenerator()
950+
agent = Agent(chat_generator=chat_generator, tools=[weather_tool], streaming_callback=async_streaming_callback)
951+
agent.warm_up()
952+
953+
with pytest.raises(ValueError, match="The init callback cannot be a coroutine"):
954+
agent.run([ChatMessage.from_user("Hello")])
955+
956+
@pytest.mark.asyncio
957+
async def test_run_async_with_sync_streaming_callback_fails(self, weather_tool):
958+
chat_generator = MockChatGenerator()
959+
agent = Agent(chat_generator=chat_generator, tools=[weather_tool], streaming_callback=sync_streaming_callback)
960+
agent.warm_up()
961+
962+
with pytest.raises(ValueError, match="The init callback must be async compatible"):
963+
await agent.run_async([ChatMessage.from_user("Hello")])
964+
931965

932966
class TestAgentTracing:
933967
def test_agent_tracing_span_run(self, caplog, monkeypatch, weather_tool):

0 commit comments

Comments
 (0)