|
1 | 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai> |
2 | 2 | # |
3 | 3 | # SPDX-License-Identifier: Apache-2.0 |
4 | | -from unittest.mock import AsyncMock, patch |
| 4 | +from unittest.mock import AsyncMock, patch, MagicMock |
5 | 5 |
|
6 | 6 | from openai import AsyncOpenAI, OpenAIError |
7 | 7 | import pytest |
@@ -365,3 +365,42 @@ async def test_live_run_with_tools_async(self, tools): |
365 | 365 | assert tool_call.tool_name == "weather" |
366 | 366 | assert tool_call.arguments == {"city": "Paris"} |
367 | 367 | assert message.meta["finish_reason"] == "tool_calls" |
| 368 | + |
| 369 | + @pytest.mark.asyncio |
| 370 | + async def test_run_with_wrapped_stream_simulation_async(self, chat_messages, openai_mock_stream_async): |
| 371 | + streaming_callback_called = False |
| 372 | + |
| 373 | + async def streaming_callback(chunk: StreamingChunk) -> None: |
| 374 | + nonlocal streaming_callback_called |
| 375 | + streaming_callback_called = True |
| 376 | + assert isinstance(chunk, StreamingChunk) |
| 377 | + |
| 378 | + chunk = ChatCompletionChunk( |
| 379 | + id="id", |
| 380 | + model="gpt-4", |
| 381 | + object="chat.completion.chunk", |
| 382 | + choices=[chat_completion_chunk.Choice(index=0, delta=chat_completion_chunk.ChoiceDelta(content="Hello"))], |
| 383 | + created=int(datetime.now().timestamp()), |
| 384 | + ) |
| 385 | + |
| 386 | + # Here we wrap the OpenAI async stream in an AsyncMock |
| 387 | + # This is to simulate the behavior of some tools like Weave (https://github.com/wandb/weave) |
| 388 | + # which wrap the OpenAI async stream in their own stream |
| 389 | + wrapped_openai_async_stream = AsyncMock() |
| 390 | + wrapped_openai_async_stream.__aiter__.return_value = iter([chunk]) |
| 391 | + |
| 392 | + component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key")) |
| 393 | + |
| 394 | + # Patch the async client's create method |
| 395 | + with patch.object( |
| 396 | + component.async_client.chat.completions, |
| 397 | + "create", |
| 398 | + return_value=wrapped_openai_async_stream, |
| 399 | + new_callable=AsyncMock, |
| 400 | + ) as mock_create: |
| 401 | + response = await component.run_async(chat_messages, streaming_callback=streaming_callback) |
| 402 | + |
| 403 | + mock_create.assert_called_once() |
| 404 | + assert streaming_callback_called |
| 405 | + assert "replies" in response |
| 406 | + assert "Hello" in response["replies"][0].text |
0 commit comments