Skip to content

Commit ddb7b5c

Browse files
committed
feat: Update WatsonXChatGenerator to use the StreamingChunk fields
1 parent 24935e4 commit ddb7b5c

2 files changed

Lines changed: 40 additions & 41 deletions

File tree

integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py

Lines changed: 33 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -6,10 +6,12 @@
66
from typing import Any, Literal, get_args
77

88
from haystack import component, default_from_dict, default_to_dict, logging
9+
from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message
910
from haystack.dataclasses import (
1011
AsyncStreamingCallbackT,
1112
ChatMessage,
1213
ChatRole,
14+
FinishReason,
1315
ImageContent,
1416
StreamingCallbackT,
1517
StreamingChunk,
@@ -29,6 +31,17 @@
2931
ImageFormat = Literal["image/jpeg", "image/png"]
3032
IMAGE_SUPPORTED_FORMATS: list[ImageFormat] = list(get_args(ImageFormat))
3133

34+
# See https://ibm.github.io/watsonx-ai-node-sdk/enums/1_6_x.WatsonXAI.TextChatResultChoiceStream.Constants.FinishReason.html
35+
# for possible finish reasons
36+
FINISH_REASON_MAPPING: dict[str, FinishReason] = {
37+
"cancelled": "stop",
38+
"error": "stop",
39+
"length": "length",
40+
"stop": "stop",
41+
"time_limit": "stop",
42+
"tool_calls": "tool_calls",
43+
}
44+
3245

3346
@component
3447
class WatsonxChatGenerator:
@@ -327,6 +340,22 @@ def _prepare_api_call(
327340

328341
return {"messages": watsonx_messages, "params": merged_kwargs}
329342

343+
def _convert_chunk_to_streaming_chunk(self, content: str, chunk: dict[str, Any]) -> StreamingChunk:
344+
"""
345+
Convert one Watsonx AI stream-chunk to Haystack StreamingChunk.
346+
"""
347+
chunk_meta = {
348+
"model": self.model,
349+
"received_at": datetime.now(timezone.utc).isoformat(),
350+
}
351+
streaming_chunk = StreamingChunk(
352+
content=content,
353+
meta=chunk_meta,
354+
index=chunk["choices"][0].get("index", 0),
355+
finish_reason=FINISH_REASON_MAPPING.get(chunk["choices"][0].get("finish_reason")),
356+
)
357+
return streaming_chunk
358+
330359
def _handle_streaming(
331360
self,
332361
*,
@@ -350,17 +379,11 @@ def _handle_streaming(
350379

351380
content = chunk["choices"][0].get("delta", {}).get("content", "")
352381
if content:
353-
chunk_meta = {
354-
"model": self.model,
355-
"index": chunk["choices"][0].get("index", 0),
356-
"finish_reason": chunk["choices"][0].get("finish_reason"),
357-
"received_at": datetime.now(timezone.utc).isoformat(),
358-
}
359-
streaming_chunk = StreamingChunk(content=content, meta=chunk_meta)
382+
streaming_chunk = self._convert_chunk_to_streaming_chunk(content, chunk)
360383
chunks.append(streaming_chunk)
361384
callback(streaming_chunk)
362385

363-
return {"replies": [self._convert_streaming_chunks_to_chat_message(chunks)]}
386+
return {"replies": [_convert_streaming_chunks_to_chat_message(chunks)]}
364387

365388
def _handle_standard(self, api_args: dict[str, Any]) -> dict[str, list[ChatMessage]]:
366389
"""Handle synchronous standard response."""
@@ -383,35 +406,11 @@ async def _handle_async_streaming(
383406

384407
content = chunk["choices"][0].get("delta", {}).get("content", "")
385408
if content:
386-
chunk_meta = {
387-
"model": self.model,
388-
"index": chunk["choices"][0].get("index", 0),
389-
"finish_reason": chunk["choices"][0].get("finish_reason"),
390-
"received_at": datetime.now(timezone.utc).isoformat(),
391-
}
392-
streaming_chunk = StreamingChunk(content=content, meta=chunk_meta)
409+
streaming_chunk = self._convert_chunk_to_streaming_chunk(content, chunk)
393410
chunks.append(streaming_chunk)
394411
await callback(streaming_chunk)
395412

396-
return {"replies": [self._convert_streaming_chunks_to_chat_message(chunks)]}
397-
398-
def _convert_streaming_chunks_to_chat_message(self, chunks: list[StreamingChunk]) -> ChatMessage:
399-
"""Convert list of streaming chunks to a single ChatMessage."""
400-
if not chunks:
401-
return ChatMessage.from_assistant("")
402-
403-
content = "".join(chunk.content for chunk in chunks)
404-
last_chunk_meta = chunks[-1].meta if chunks else {}
405-
406-
return ChatMessage.from_assistant(
407-
text=content,
408-
meta={
409-
"model": self.model,
410-
"finish_reason": last_chunk_meta.get("finish_reason"),
411-
"usage": last_chunk_meta.get("usage", {}),
412-
"chunks_count": len(chunks),
413-
},
414-
)
413+
return {"replies": [_convert_streaming_chunks_to_chat_message(chunks)]}
415414

416415
async def _handle_async_standard(self, api_args: dict[str, Any]) -> dict[str, list[ChatMessage]]:
417416
"""Handle asynchronous standard response."""

integrations/watsonx/tests/test_chat_generator.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -41,7 +41,7 @@ def mock_watsonx(self, monkeypatch):
4141
{
4242
"message": {"content": "This is a generated response", "role": "assistant"},
4343
"index": 0,
44-
"finish_reason": "completed",
44+
"finish_reason": "stop",
4545
}
4646
],
4747
"usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
@@ -53,7 +53,7 @@ def mock_watsonx(self, monkeypatch):
5353
{
5454
"message": {"content": "Async generated response", "role": "assistant"},
5555
"index": 0,
56-
"finish_reason": "completed",
56+
"finish_reason": "stop",
5757
}
5858
],
5959
"usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
@@ -62,7 +62,7 @@ def mock_watsonx(self, monkeypatch):
6262
mock_model_instance.chat_stream = MagicMock(
6363
return_value=[
6464
{"choices": [{"delta": {"content": "Streaming"}, "index": 0, "finish_reason": None}]},
65-
{"choices": [{"delta": {"content": " response"}, "index": 0, "finish_reason": "completed"}]},
65+
{"choices": [{"delta": {"content": " response"}, "index": 0, "finish_reason": "stop"}]},
6666
]
6767
)
6868

@@ -85,7 +85,7 @@ async def __anext__(self):
8585
elif self._count == 2:
8686
return {
8787
"choices": [
88-
{"delta": {"content": " response"}, "finish_reason": "completed", "index": 0}
88+
{"delta": {"content": " response"}, "finish_reason": "stop", "index": 0}
8989
]
9090
}
9191
else:
@@ -227,7 +227,7 @@ def test_run_single_message(self, mock_watsonx):
227227

228228
assert len(result["replies"]) == 1
229229
assert result["replies"][0].text == "This is a generated response"
230-
assert result["replies"][0].meta["finish_reason"] == "completed"
230+
assert result["replies"][0].meta["finish_reason"] == "stop"
231231

232232
mock_watsonx["model_instance"].chat.assert_called_once_with(
233233
messages=[{"role": "user", "content": "Test prompt"}], params={}
@@ -273,7 +273,7 @@ def test_run_with_streaming(self, mock_watsonx):
273273

274274
assert len(result["replies"]) == 1
275275
assert result["replies"][0].text == "Streaming response"
276-
assert result["replies"][0].meta["finish_reason"] == "completed"
276+
assert result["replies"][0].meta["finish_reason"] == "stop"
277277

278278
def test_run_with_empty_messages(self, mock_watsonx):
279279
generator = WatsonxChatGenerator(
@@ -338,7 +338,7 @@ async def test_run_async_single_message(self, mock_watsonx):
338338

339339
assert len(result["replies"]) == 1
340340
assert result["replies"][0].text == "Async generated response"
341-
assert result["replies"][0].meta["finish_reason"] == "completed"
341+
assert result["replies"][0].meta["finish_reason"] == "stop"
342342

343343
@pytest.mark.asyncio
344344
async def test_run_async_streaming(self, mock_watsonx):

0 commit comments

Comments
 (0)