|
29 | 29 | from haystack.utils import Secret, serialize_callable |
30 | 30 |
|
31 | 31 |
|
32 | | -def streaming_callback_for_serde(chunk: StreamingChunk): |
| 32 | +def sync_streaming_callback(chunk: StreamingChunk) -> None: |
| 33 | + """A synchronous streaming callback.""" |
| 34 | + pass |
| 35 | + |
| 36 | + |
| 37 | +async def async_streaming_callback(chunk: StreamingChunk) -> None: |
| 38 | + """An asynchronous streaming callback.""" |
33 | 39 | pass |
34 | 40 |
|
35 | 41 |
|
@@ -501,18 +507,16 @@ def test_serde_with_streaming_callback(self, weather_tool, component_tool, monke |
501 | 507 | monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key") |
502 | 508 | generator = OpenAIChatGenerator(api_key=Secret.from_env_var("FAKE_OPENAI_KEY")) |
503 | 509 | agent = Agent( |
504 | | - chat_generator=generator, |
505 | | - tools=[weather_tool, component_tool], |
506 | | - streaming_callback=streaming_callback_for_serde, |
| 510 | + chat_generator=generator, tools=[weather_tool, component_tool], streaming_callback=sync_streaming_callback |
507 | 511 | ) |
508 | 512 |
|
509 | 513 | serialized_agent = agent.to_dict() |
510 | 514 |
|
511 | 515 | init_parameters = serialized_agent["init_parameters"] |
512 | | - assert init_parameters["streaming_callback"] == "test_agent.streaming_callback_for_serde" |
| 516 | + assert init_parameters["streaming_callback"] == "test_agent.sync_streaming_callback" |
513 | 517 |
|
514 | 518 | deserialized_agent = Agent.from_dict(serialized_agent) |
515 | | - assert deserialized_agent.streaming_callback is streaming_callback_for_serde |
| 519 | + assert deserialized_agent.streaming_callback is sync_streaming_callback |
516 | 520 |
|
517 | 521 | def test_exit_conditions_validation(self, weather_tool, component_tool, monkeypatch): |
518 | 522 | monkeypatch.setenv("FAKE_OPENAI_KEY", "fake-key") |
@@ -928,6 +932,36 @@ def streaming_callback(chunk: StreamingChunk) -> None: |
928 | 932 | assert result["last_message"] is not None |
929 | 933 | assert streaming_callback_called |
930 | 934 |
|
| 935 | + @pytest.mark.asyncio |
| 936 | + async def test_run_async_with_async_streaming_callback(self, weather_tool): |
| 937 | + chat_generator = MockChatGenerator() |
| 938 | + agent = Agent(chat_generator=chat_generator, tools=[weather_tool], streaming_callback=async_streaming_callback) |
| 939 | + agent.warm_up() |
| 940 | + |
| 941 | + # This should not raise any exception |
| 942 | + result = await agent.run_async([ChatMessage.from_user("Hello")]) |
| 943 | + |
| 944 | + assert "messages" in result |
| 945 | + assert len(result["messages"]) == 2 |
| 946 | + assert result["messages"][1].text == "Hello from run_async" |
| 947 | + |
| 948 | + def test_run_with_async_streaming_callback_fails(self, weather_tool): |
| 949 | + chat_generator = MockChatGenerator() |
| 950 | + agent = Agent(chat_generator=chat_generator, tools=[weather_tool], streaming_callback=async_streaming_callback) |
| 951 | + agent.warm_up() |
| 952 | + |
| 953 | + with pytest.raises(ValueError, match="The init callback cannot be a coroutine"): |
| 954 | + agent.run([ChatMessage.from_user("Hello")]) |
| 955 | + |
| 956 | + @pytest.mark.asyncio |
| 957 | + async def test_run_async_with_sync_streaming_callback_fails(self, weather_tool): |
| 958 | + chat_generator = MockChatGenerator() |
| 959 | + agent = Agent(chat_generator=chat_generator, tools=[weather_tool], streaming_callback=sync_streaming_callback) |
| 960 | + agent.warm_up() |
| 961 | + |
| 962 | + with pytest.raises(ValueError, match="The init callback must be async compatible"): |
| 963 | + await agent.run_async([ChatMessage.from_user("Hello")]) |
| 964 | + |
931 | 965 |
|
932 | 966 | class TestAgentTracing: |
933 | 967 | def test_agent_tracing_span_run(self, caplog, monkeypatch, weather_tool): |
|
0 commit comments