# Image MIME types accepted by Cohere's multimodal chat endpoint.
ImageFormat = Literal["image/png", "image/jpeg", "image/webp", "image/gif"]
IMAGE_SUPPORTED_FORMATS: list[ImageFormat] = list(get_args(ImageFormat))

# Translates Cohere's finish-reason strings into Haystack's FinishReason
# vocabulary. Reasons absent from this table (e.g. errors) are passed through
# unmapped by the callers that use `.get(raw, raw)`.
FINISH_REASON_MAPPING: dict[str, FinishReason] = {
    "COMPLETE": "stop",
    "STOP_SEQUENCE": "stop",
    "MAX_TOKENS": "length",
    "TOOL_CALL": "tool_calls",
}
47+
4148
4249def _format_tool (tool : Tool ) -> dict [str , Any ]:
4350 """
@@ -51,17 +58,11 @@ def _format_tool(tool: Tool) -> dict[str, Any]:
5158 """
5259 return {
5360 "type" : "function" ,
54- "function" : {
55- "name" : tool .name ,
56- "description" : tool .description ,
57- "parameters" : tool .parameters ,
58- },
61+ "function" : {"name" : tool .name , "description" : tool .description , "parameters" : tool .parameters },
5962 }
6063
6164
62- def _format_message (
63- message : ChatMessage ,
64- ) -> dict [str , Any ]:
65+ def _format_message (message : ChatMessage ) -> dict [str , Any ]:
6566 """
6667 Formats a Haystack ChatMessage into Cohere's chat format.
6768
@@ -102,17 +103,10 @@ def _format_message(
102103 {
103104 "id" : tool_call .id ,
104105 "type" : "function" ,
105- "function" : {
106- "name" : tool_call .tool_name ,
107- "arguments" : json .dumps (tool_call .arguments ),
108- },
106+ "function" : {"name" : tool_call .tool_name , "arguments" : json .dumps (tool_call .arguments )},
109107 }
110108 )
111- return {
112- "role" : "assistant" ,
113- "tool_calls" : tool_calls ,
114- "tool_plan" : message .text if message .text else "" ,
115- }
109+ return {"role" : "assistant" , "tool_calls" : tool_calls , "tool_plan" : message .text if message .text else "" }
116110
117111 if message .role .value == "user" :
118112 if not message .images and not message .text :
def _parse_response(chat_response: ChatResponse, model: str) -> ChatMessage:
    """
    Parses Cohere's non-streaming chat response into a Haystack ChatMessage.

    :param chat_response: The response object returned by the Cohere V2 client.
    :param model: The name of the model that generated the response.
    :return: A Haystack ChatMessage containing the formatted response.
    """
    # Pull the text out of the response content. When several text items are
    # present the last one wins. NOTE(review): confirm multi-item responses
    # should not be concatenated instead.
    text_pieces = [item.text for item in (chat_response.message.content or []) if item.type == "text"]
    text_content = text_pieces[-1] if text_pieces else ""

    # Collect only well-formed tool calls: a function with a name and
    # JSON-string arguments. Stays None when the response carries none at all.
    tool_calls = None
    if chat_response.message.tool_calls:
        tool_calls = [
            ToolCall(id=tc.id, tool_name=tc.function.name, arguments=json.loads(tc.function.arguments))
            for tc in chat_response.message.tool_calls
            if tc.function and tc.function.name and tc.function.arguments and isinstance(tc.function.arguments, str)
        ]
        # A tool plan, when provided, takes precedence over the plain text content.
        text_content = chat_response.message.tool_plan or text_content

    # Normalize Cohere's finish reason to Haystack's vocabulary, falling back
    # to the raw value for reasons we do not map.
    raw_reason = chat_response.finish_reason
    meta: dict[str, Any] = {
        "model": model,
        "index": 0,
        "finish_reason": FINISH_REASON_MAPPING.get(raw_reason, raw_reason) if raw_reason else None,
        "citations": chat_response.message.citations,
    }
    # In V2, token usage is part of the response object, not the message.
    billed = chat_response.usage.billed_units if chat_response.usage else None
    if billed:
        meta["usage"] = {
            "prompt_tokens": billed.input_tokens,
            "completion_tokens": billed.output_tokens,
        }

    return ChatMessage.from_assistant(text=text_content, tool_calls=tool_calls, meta=meta)
215210
216211
@@ -219,6 +214,7 @@ def _convert_cohere_chunk_to_streaming_chunk(
219214 model : str ,
220215 component_info : ComponentInfo | None = None ,
221216 global_index : int = 0 ,
217+ previous_original_chunks : list [StreamedChatResponseV2 ] | None = None ,
222218) -> StreamingChunk :
223219 """
224220 Converts a Cohere streaming response chunk to a StreamingChunk.
@@ -237,12 +233,6 @@ def _convert_cohere_chunk_to_streaming_chunk(
237233 :returns:
238234 A StreamingChunk object representing the content of the chunk from the Cohere API.
239235 """
240- finish_reason_mapping : dict [str , FinishReason ] = {
241- "COMPLETE" : "stop" ,
242- "MAX_TOKENS" : "length" ,
243- "TOOL_CALLS" : "tool_calls" ,
244- }
245-
246236 # Initialize default values
247237 content = ""
248238 index = global_index
@@ -254,24 +244,23 @@ def _convert_cohere_chunk_to_streaming_chunk(
254244 if chunk .type == "content-delta" and chunk .delta and chunk .delta .message :
255245 if chunk .delta .message and chunk .delta .message .content and chunk .delta .message .content .text is not None :
256246 content = chunk .delta .message .content .text
247+ # If the previous chunk is a content-start chunk, we set start to True for the first content-delta chunk
248+ if previous_original_chunks and previous_original_chunks [- 1 ].type == "content-start" :
249+ start = True
257250
258251 elif chunk .type == "tool-plan-delta" and chunk .delta and chunk .delta .message :
259252 if chunk .delta .message and chunk .delta .message .tool_plan is not None :
260253 content = chunk .delta .message .tool_plan
254+ # If the previous chunk is a message-start chunk, we set start to True for the first tool-plan-delta chunk
255+ if previous_original_chunks and previous_original_chunks [- 1 ].type == "message-start" :
256+ start = True
261257
262258 elif chunk .type == "tool-call-start" and chunk .delta and chunk .delta .message :
263259 if chunk .delta .message and chunk .delta .message .tool_calls :
264260 tool_call = chunk .delta .message .tool_calls
265261 function = tool_call .function
266262 if function is not None and function .name is not None :
267- tool_calls = [
268- ToolCallDelta (
269- index = global_index ,
270- id = tool_call .id ,
271- tool_name = function .name ,
272- arguments = None ,
273- )
274- ]
263+ tool_calls = [ToolCallDelta (index = global_index , id = tool_call .id , tool_name = function .name )]
275264 start = True # This starts a tool call
276265 if tool_call .id is not None :
277266 meta ["tool_call_id" ] = tool_call .id
@@ -284,21 +273,11 @@ def _convert_cohere_chunk_to_streaming_chunk(
284273 and chunk .delta .message .tool_calls .function .arguments is not None
285274 ):
286275 arguments = chunk .delta .message .tool_calls .function .arguments
287- tool_calls = [
288- ToolCallDelta (
289- index = global_index ,
290- tool_name = None ,
291- arguments = arguments ,
292- )
293- ]
294-
295- elif chunk .type == "tool-call-end" :
296- # Tool call end doesn't have content, just signals completion
297- start = True
276+ tool_calls = [ToolCallDelta (index = global_index , arguments = arguments )]
298277
299278 elif chunk .type == "message-end" :
300279 finish_reason_raw = getattr (chunk .delta , "finish_reason" , None )
301- finish_reason = finish_reason_mapping .get (finish_reason_raw ) if finish_reason_raw else None
280+ finish_reason = FINISH_REASON_MAPPING .get (finish_reason_raw ) if finish_reason_raw else None
302281
303282 # The Cohere API is subject to changes in how usage data is returned. We try to support both dict and objects.
304283 usage_data = getattr (chunk .delta , "usage" , None )
@@ -346,6 +325,7 @@ def _parse_streaming_response(
346325
347326 Loops through each stream object from Cohere and converts it into a StreamingChunk.
348327 """
328+ original_chunks : list [StreamedChatResponseV2 ] = []
349329 chunks : list [StreamingChunk ] = []
350330 global_index = 0
351331
@@ -358,11 +338,10 @@ def _parse_streaming_response(
358338 component_info = component_info ,
359339 model = model ,
360340 global_index = global_index ,
341+ previous_original_chunks = original_chunks ,
361342 )
362343
363- if not streaming_chunk :
364- continue
365-
344+ original_chunks .append (chunk )
366345 chunks .append (streaming_chunk )
367346 streaming_callback (streaming_chunk )
368347
@@ -378,6 +357,7 @@ async def _parse_async_streaming_response(
378357 """
379358 Parses Cohere's async streaming chat response into a Haystack ChatMessage.
380359 """
360+ original_chunks : list [StreamedChatResponseV2 ] = []
381361 chunks : list [StreamingChunk ] = []
382362 global_index = 0
383363
@@ -386,11 +366,14 @@ async def _parse_async_streaming_response(
386366 global_index += 1
387367
388368 streaming_chunk = _convert_cohere_chunk_to_streaming_chunk (
389- chunk = chunk , component_info = component_info , model = model , global_index = global_index
369+ chunk = chunk ,
370+ component_info = component_info ,
371+ model = model ,
372+ global_index = global_index ,
373+ previous_original_chunks = original_chunks ,
390374 )
391- if not streaming_chunk :
392- continue
393375
376+ original_chunks .append (chunk )
394377 chunks .append (streaming_chunk )
395378 await streaming_callback (streaming_chunk )
396379
@@ -638,10 +621,7 @@ def run(
638621 """
639622
640623 # update generation kwargs by merging with the generation kwargs passed to the run method
641- generation_kwargs = {
642- ** self .generation_kwargs ,
643- ** (generation_kwargs or {}),
644- }
624+ generation_kwargs = {** self .generation_kwargs , ** (generation_kwargs or {})}
645625
646626 # Handle tools
647627 tools = tools or self .tools
@@ -705,10 +685,7 @@ async def run_async(
705685 """
706686
707687 # update generation kwargs by merging with the generation kwargs passed to the run method
708- generation_kwargs = {
709- ** self .generation_kwargs ,
710- ** (generation_kwargs or {}),
711- }
688+ generation_kwargs = {** self .generation_kwargs , ** (generation_kwargs or {})}
712689
713690 # Handle tools
714691 tools = tools or self .tools
0 commit comments