Skip to content

Commit 88ce981

Browse files
mpangrazzi and julian-risch
authored and committed
Fix OpenAIGenerator and OpenAIChatGenerator to allow wrapped streaming objects usage (#9304)
* Fix for handling wrapped ChatCompletion instances in streaming (used by tools like Weave)
* Add release note
* Applied same fix to OpenAIGenerator; refactoring; update release note
* Fix integration test error after refactoring
1 parent 8d109a9 commit 88ce981

6 files changed

Lines changed: 137 additions & 22 deletions

File tree

haystack/components/generators/chat/openai.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -274,10 +274,7 @@ def run(
274274
**api_args
275275
)
276276

277-
is_streaming = isinstance(chat_completion, Stream)
278-
assert is_streaming or streaming_callback is None
279-
280-
if is_streaming:
277+
if streaming_callback is not None:
281278
completions = self._handle_stream_response(
282279
chat_completion, # type: ignore
283280
streaming_callback, # type: ignore
@@ -353,10 +350,7 @@ async def run_async(
353350
AsyncStream[ChatCompletionChunk], ChatCompletion
354351
] = await self.async_client.chat.completions.create(**api_args)
355352

356-
is_streaming = isinstance(chat_completion, AsyncStream)
357-
assert is_streaming or streaming_callback is None
358-
359-
if is_streaming:
353+
if streaming_callback is not None:
360354
completions = await self._handle_async_stream_response(
361355
chat_completion, # type: ignore
362356
streaming_callback, # type: ignore

haystack/components/generators/openai.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -224,22 +224,25 @@ def run(
224224
)
225225

226226
completions: List[ChatMessage] = []
227-
if isinstance(completion, Stream):
227+
if streaming_callback is not None:
228228
num_responses = generation_kwargs.pop("n", 1)
229229
if num_responses > 1:
230230
raise ValueError("Cannot stream multiple responses, please set n=1.")
231231
chunks: List[StreamingChunk] = []
232-
completion_chunk: Optional[ChatCompletionChunk] = None
233-
234-
# pylint: disable=not-an-iterable
235-
for completion_chunk in completion:
236-
if completion_chunk.choices and streaming_callback:
237-
chunk_delta: StreamingChunk = self._build_chunk(completion_chunk)
238-
chunks.append(chunk_delta)
239-
streaming_callback(chunk_delta) # invoke callback with the chunk_delta
240-
# Makes type checkers happy
241-
assert completion_chunk is not None
242-
completions = [self._create_message_from_chunks(completion_chunk, chunks)]
232+
last_chunk: Optional[ChatCompletionChunk] = None
233+
234+
for chunk in completion:
235+
if isinstance(chunk, ChatCompletionChunk):
236+
last_chunk = chunk
237+
238+
if chunk.choices:
239+
chunk_delta: StreamingChunk = self._build_chunk(chunk)
240+
chunks.append(chunk_delta)
241+
streaming_callback(chunk_delta)
242+
243+
assert last_chunk is not None
244+
245+
completions = [self._create_message_from_chunks(last_chunk, chunks)]
243246
elif isinstance(completion, ChatCompletion):
244247
completions = [self._build_message(completion, choice) for choice in completion.choices]
245248

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
fixes:
3+
- |
4+
Fix an issue where OpenAIChatGenerator and OpenAIGenerator were not properly handling wrapped streaming responses from tools like Weave.

test/components/generators/chat/test_openai.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
22
#
33
# SPDX-License-Identifier: Apache-2.0
4-
from unittest.mock import patch
4+
from unittest.mock import patch, MagicMock, AsyncMock
55
import pytest
66

77

@@ -364,6 +364,40 @@ def streaming_callback(chunk: StreamingChunk) -> None:
364364
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
365365
assert "Hello" in response["replies"][0].text # see openai_mock_chat_completion_chunk
366366

367+
def test_run_with_wrapped_stream_simulation(self, chat_messages, openai_mock_stream):
368+
streaming_callback_called = False
369+
370+
def streaming_callback(chunk: StreamingChunk) -> None:
371+
nonlocal streaming_callback_called
372+
streaming_callback_called = True
373+
assert isinstance(chunk, StreamingChunk)
374+
375+
chunk = ChatCompletionChunk(
376+
id="id",
377+
model="gpt-4",
378+
object="chat.completion.chunk",
379+
choices=[chat_completion_chunk.Choice(index=0, delta=chat_completion_chunk.ChoiceDelta(content="Hello"))],
380+
created=int(datetime.now().timestamp()),
381+
)
382+
383+
# Here we wrap the OpenAI stream in a MagicMock
384+
# This is to simulate the behavior of some tools like Weave (https://github.com/wandb/weave)
385+
# which wrap the OpenAI stream in their own stream
386+
wrapped_openai_stream = MagicMock()
387+
wrapped_openai_stream.__iter__.return_value = iter([chunk])
388+
389+
component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
390+
391+
with patch.object(
392+
component.client.chat.completions, "create", return_value=wrapped_openai_stream
393+
) as mock_create:
394+
response = component.run(chat_messages, streaming_callback=streaming_callback)
395+
396+
mock_create.assert_called_once()
397+
assert streaming_callback_called
398+
assert "replies" in response
399+
assert "Hello" in response["replies"][0].text
400+
367401
def test_check_abnormal_completions(self, caplog):
368402
caplog.set_level(logging.INFO)
369403
component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))

test/components/generators/chat/test_openai_async.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
22
#
33
# SPDX-License-Identifier: Apache-2.0
4-
from unittest.mock import AsyncMock, patch
4+
from unittest.mock import AsyncMock, patch, MagicMock
55

66
from openai import AsyncOpenAI, OpenAIError
77
import pytest
@@ -365,3 +365,42 @@ async def test_live_run_with_tools_async(self, tools):
365365
assert tool_call.tool_name == "weather"
366366
assert tool_call.arguments == {"city": "Paris"}
367367
assert message.meta["finish_reason"] == "tool_calls"
368+
369+
@pytest.mark.asyncio
370+
async def test_run_with_wrapped_stream_simulation_async(self, chat_messages, openai_mock_stream_async):
371+
streaming_callback_called = False
372+
373+
async def streaming_callback(chunk: StreamingChunk) -> None:
374+
nonlocal streaming_callback_called
375+
streaming_callback_called = True
376+
assert isinstance(chunk, StreamingChunk)
377+
378+
chunk = ChatCompletionChunk(
379+
id="id",
380+
model="gpt-4",
381+
object="chat.completion.chunk",
382+
choices=[chat_completion_chunk.Choice(index=0, delta=chat_completion_chunk.ChoiceDelta(content="Hello"))],
383+
created=int(datetime.now().timestamp()),
384+
)
385+
386+
# Here we wrap the OpenAI async stream in an AsyncMock
387+
# This is to simulate the behavior of some tools like Weave (https://github.com/wandb/weave)
388+
# which wrap the OpenAI async stream in their own stream
389+
wrapped_openai_async_stream = AsyncMock()
390+
wrapped_openai_async_stream.__aiter__.return_value = iter([chunk])
391+
392+
component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
393+
394+
# Patch the async client's create method
395+
with patch.object(
396+
component.async_client.chat.completions,
397+
"create",
398+
return_value=wrapped_openai_async_stream,
399+
new_callable=AsyncMock,
400+
) as mock_create:
401+
response = await component.run_async(chat_messages, streaming_callback=streaming_callback)
402+
403+
mock_create.assert_called_once()
404+
assert streaming_callback_called
405+
assert "replies" in response
406+
assert "Hello" in response["replies"][0].text

test/components/generators/test_openai.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
import pytest
1010
from openai import OpenAIError
11+
from openai.types.chat import ChatCompletionChunk, chat_completion_chunk
12+
from unittest.mock import MagicMock, patch
1113

1214
from haystack.components.generators import OpenAIGenerator
1315
from haystack.components.generators.utils import print_streaming_chunk
@@ -359,3 +361,42 @@ def __call__(self, chunk: StreamingChunk) -> None:
359361

360362
assert callback.counter > 1
361363
assert "Paris" in callback.responses
364+
365+
def test_run_with_wrapped_stream_simulation(self, openai_mock_stream):
366+
streaming_callback_called = False
367+
368+
def streaming_callback(chunk: StreamingChunk) -> None:
369+
nonlocal streaming_callback_called
370+
streaming_callback_called = True
371+
assert isinstance(chunk, StreamingChunk)
372+
373+
chunk = ChatCompletionChunk(
374+
id="id",
375+
model="gpt-4",
376+
object="chat.completion.chunk",
377+
choices=[
378+
chat_completion_chunk.Choice(
379+
index=0, delta=chat_completion_chunk.ChoiceDelta(content="Hello"), finish_reason="stop"
380+
)
381+
],
382+
created=int(datetime.now().timestamp()),
383+
)
384+
385+
# Here we wrap the OpenAI stream in a MagicMock
386+
# This is to simulate the behavior of some tools like Weave (https://github.com/wandb/weave)
387+
# which wrap the OpenAI stream in their own stream
388+
wrapped_openai_stream = MagicMock()
389+
wrapped_openai_stream.__iter__.return_value = iter([chunk])
390+
391+
component = OpenAIGenerator(api_key=Secret.from_token("test-api-key"))
392+
393+
with patch.object(
394+
component.client.chat.completions, "create", return_value=wrapped_openai_stream
395+
) as mock_create:
396+
response = component.run(prompt="test prompt", streaming_callback=streaming_callback)
397+
398+
mock_create.assert_called_once()
399+
assert streaming_callback_called
400+
assert "replies" in response
401+
assert "Hello" in response["replies"][0]
402+
assert response["meta"][0]["finish_reason"] == "stop"

0 commit comments

Comments (0)