66from typing import Any , Literal , get_args
77
88from haystack import component , default_from_dict , default_to_dict , logging
9+ from haystack .components .generators .utils import _convert_streaming_chunks_to_chat_message
910from haystack .dataclasses import (
1011 AsyncStreamingCallbackT ,
1112 ChatMessage ,
1213 ChatRole ,
14+ FinishReason ,
1315 ImageContent ,
1416 StreamingCallbackT ,
1517 StreamingChunk ,
2931ImageFormat = Literal ["image/jpeg" , "image/png" ]
3032IMAGE_SUPPORTED_FORMATS : list [ImageFormat ] = list (get_args (ImageFormat ))
3133
34+ # See https://ibm.github.io/watsonx-ai-node-sdk/enums/1_6_x.WatsonXAI.TextChatResultChoiceStream.Constants.FinishReason.html
35+ # for possible finish reasons
36+ FINISH_REASON_MAPPING : dict [str , FinishReason ] = {
37+ "cancelled" : "stop" ,
38+ "error" : "stop" ,
39+ "length" : "length" ,
40+ "stop" : "stop" ,
41+ "time_limit" : "stop" ,
42+ "tool_calls" : "tool_calls" ,
43+ }
44+
3245
3346@component
3447class WatsonxChatGenerator :
@@ -327,6 +340,22 @@ def _prepare_api_call(
327340
328341 return {"messages" : watsonx_messages , "params" : merged_kwargs }
329342
343+ def _convert_chunk_to_streaming_chunk (self , content : str , chunk : dict [str , Any ]) -> StreamingChunk :
344+ """
345+ Convert one Watsonx AI stream-chunk to Haystack StreamingChunk.
346+ """
347+ chunk_meta = {
348+ "model" : self .model ,
349+ "received_at" : datetime .now (timezone .utc ).isoformat (),
350+ }
351+ streaming_chunk = StreamingChunk (
352+ content = content ,
353+ meta = chunk_meta ,
354+ index = chunk ["choices" ][0 ].get ("index" , 0 ),
355+ finish_reason = FINISH_REASON_MAPPING .get (chunk ["choices" ][0 ].get ("finish_reason" )),
356+ )
357+ return streaming_chunk
358+
330359 def _handle_streaming (
331360 self ,
332361 * ,
@@ -350,17 +379,11 @@ def _handle_streaming(
350379
351380 content = chunk ["choices" ][0 ].get ("delta" , {}).get ("content" , "" )
352381 if content :
353- chunk_meta = {
354- "model" : self .model ,
355- "index" : chunk ["choices" ][0 ].get ("index" , 0 ),
356- "finish_reason" : chunk ["choices" ][0 ].get ("finish_reason" ),
357- "received_at" : datetime .now (timezone .utc ).isoformat (),
358- }
359- streaming_chunk = StreamingChunk (content = content , meta = chunk_meta )
382+ streaming_chunk = self ._convert_chunk_to_streaming_chunk (content , chunk )
360383 chunks .append (streaming_chunk )
361384 callback (streaming_chunk )
362385
363- return {"replies" : [self . _convert_streaming_chunks_to_chat_message (chunks )]}
386+ return {"replies" : [_convert_streaming_chunks_to_chat_message (chunks )]}
364387
365388 def _handle_standard (self , api_args : dict [str , Any ]) -> dict [str , list [ChatMessage ]]:
366389 """Handle synchronous standard response."""
@@ -383,35 +406,11 @@ async def _handle_async_streaming(
383406
384407 content = chunk ["choices" ][0 ].get ("delta" , {}).get ("content" , "" )
385408 if content :
386- chunk_meta = {
387- "model" : self .model ,
388- "index" : chunk ["choices" ][0 ].get ("index" , 0 ),
389- "finish_reason" : chunk ["choices" ][0 ].get ("finish_reason" ),
390- "received_at" : datetime .now (timezone .utc ).isoformat (),
391- }
392- streaming_chunk = StreamingChunk (content = content , meta = chunk_meta )
409+ streaming_chunk = self ._convert_chunk_to_streaming_chunk (content , chunk )
393410 chunks .append (streaming_chunk )
394411 await callback (streaming_chunk )
395412
396- return {"replies" : [self ._convert_streaming_chunks_to_chat_message (chunks )]}
397-
398- def _convert_streaming_chunks_to_chat_message (self , chunks : list [StreamingChunk ]) -> ChatMessage :
399- """Convert list of streaming chunks to a single ChatMessage."""
400- if not chunks :
401- return ChatMessage .from_assistant ("" )
402-
403- content = "" .join (chunk .content for chunk in chunks )
404- last_chunk_meta = chunks [- 1 ].meta if chunks else {}
405-
406- return ChatMessage .from_assistant (
407- text = content ,
408- meta = {
409- "model" : self .model ,
410- "finish_reason" : last_chunk_meta .get ("finish_reason" ),
411- "usage" : last_chunk_meta .get ("usage" , {}),
412- "chunks_count" : len (chunks ),
413- },
414- )
413+ return {"replies" : [_convert_streaming_chunks_to_chat_message (chunks )]}
415414
416415 async def _handle_async_standard (self , api_args : dict [str , Any ]) -> dict [str , list [ChatMessage ]]:
417416 """Handle asynchronous standard response."""
0 commit comments