3838ImageFormat = Literal ["image/png" , "image/jpeg" , "image/webp" , "image/gif" ]
3939IMAGE_SUPPORTED_FORMATS : list [ImageFormat ] = list (get_args (ImageFormat ))
4040
41- FINISH_REASON_MAPPING : dict [str , FinishReason ] = {
42- "COMPLETE" : "stop" ,
43- "STOP_SEQUENCE" : "stop" ,
44- "MAX_TOKENS" : "length" ,
45- "TOOL_CALL" : "tool_calls" ,
46- }
47-
4841
4942def _format_tool (tool : Tool ) -> dict [str , Any ]:
5043 """
@@ -58,11 +51,17 @@ def _format_tool(tool: Tool) -> dict[str, Any]:
5851 """
5952 return {
6053 "type" : "function" ,
61- "function" : {"name" : tool .name , "description" : tool .description , "parameters" : tool .parameters },
54+ "function" : {
55+ "name" : tool .name ,
56+ "description" : tool .description ,
57+ "parameters" : tool .parameters ,
58+ },
6259 }
6360
6461
65- def _format_message (message : ChatMessage ) -> dict [str , Any ]:
62+ def _format_message (
63+ message : ChatMessage ,
64+ ) -> dict [str , Any ]:
6665 """
6766 Formats a Haystack ChatMessage into Cohere's chat format.
6867
@@ -103,10 +102,17 @@ def _format_message(message: ChatMessage) -> dict[str, Any]:
103102 {
104103 "id" : tool_call .id ,
105104 "type" : "function" ,
106- "function" : {"name" : tool_call .tool_name , "arguments" : json .dumps (tool_call .arguments )},
105+ "function" : {
106+ "name" : tool_call .tool_name ,
107+ "arguments" : json .dumps (tool_call .arguments ),
108+ },
107109 }
108110 )
109- return {"role" : "assistant" , "tool_calls" : tool_calls , "tool_plan" : message .text if message .text else "" }
111+ return {
112+ "role" : "assistant" ,
113+ "tool_calls" : tool_calls ,
114+ "tool_plan" : message .text if message .text else "" ,
115+ }
110116
111117 if message .role .value == "user" :
112118 if not message .images and not message .text :
@@ -169,43 +175,42 @@ def _parse_response(chat_response: ChatResponse, model: str) -> ChatMessage:
169175 :param model: The name of the model that generated the response.
170176 :return: A Haystack ChatMessage containing the formatted response.
171177 """
172- # Extract text content from the response
173- text_content = ""
174- if chat_response .message .content :
175- for content_item in chat_response .message .content :
176- if content_item .type == "text" :
177- text_content = content_item .text
178-
179- # Extract tool calls if present in the response
180- tool_calls = None
181178 if chat_response .message .tool_calls :
182179 tool_calls = []
183180 for tc in chat_response .message .tool_calls :
184181 if tc .function and tc .function .name and tc .function .arguments and isinstance (tc .function .arguments , str ):
185182 tool_calls .append (
186- ToolCall (id = tc .id , tool_name = tc .function .name , arguments = json .loads (tc .function .arguments ))
183+ ToolCall (
184+ id = tc .id ,
185+ tool_name = tc .function .name ,
186+ arguments = json .loads (tc .function .arguments ),
187+ )
187188 )
188- # If a tool plan is provided we use that as our text content over the default text content
189- text_content = chat_response .message .tool_plan or text_content
190-
191- # Create metadata for the message
192- resolved_finish_reason = None
193- if chat_response .finish_reason :
194- resolved_finish_reason = FINISH_REASON_MAPPING .get (chat_response .finish_reason , chat_response .finish_reason )
195- base_meta = {
196- "model" : model ,
197- "index" : 0 ,
198- "finish_reason" : resolved_finish_reason ,
199- "citations" : chat_response .message .citations ,
200- }
189+
190+ # Create message with tool plan as text and tool calls in the format Haystack expects
191+ tool_plan = chat_response .message .tool_plan or ""
192+ message = ChatMessage .from_assistant (text = tool_plan , tool_calls = tool_calls )
193+ elif chat_response .message .content and hasattr (chat_response .message .content [0 ], "text" ):
194+ message = ChatMessage .from_assistant (chat_response .message .content [0 ].text )
195+ else :
196+ # Handle the case where neither tool_calls nor content exists
197+ logger .warning (f"Received empty response from Cohere API: { chat_response .message } " )
198+ message = ChatMessage .from_assistant ("" )
199+
201200 # In V2, token usage is part of the response object, not the message
201+ message ._meta .update (
202+ {
203+ "model" : model ,
204+ "index" : 0 ,
205+ "finish_reason" : chat_response .finish_reason ,
206+ "citations" : chat_response .message .citations ,
207+ }
208+ )
202209 if chat_response .usage and chat_response .usage .billed_units :
203- base_meta ["usage" ] = {
210+ message . _meta ["usage" ] = {
204211 "prompt_tokens" : chat_response .usage .billed_units .input_tokens ,
205212 "completion_tokens" : chat_response .usage .billed_units .output_tokens ,
206213 }
207-
208- message = ChatMessage .from_assistant (text = text_content , tool_calls = tool_calls , meta = base_meta )
209214 return message
210215
211216
@@ -214,7 +219,6 @@ def _convert_cohere_chunk_to_streaming_chunk(
214219 model : str ,
215220 component_info : ComponentInfo | None = None ,
216221 global_index : int = 0 ,
217- previous_original_chunks : list [StreamedChatResponseV2 ] | None = None ,
218222) -> StreamingChunk :
219223 """
220224 Converts a Cohere streaming response chunk to a StreamingChunk.
@@ -233,6 +237,12 @@ def _convert_cohere_chunk_to_streaming_chunk(
233237 :returns:
234238 A StreamingChunk object representing the content of the chunk from the Cohere API.
235239 """
240+ finish_reason_mapping : dict [str , FinishReason ] = {
241+ "COMPLETE" : "stop" ,
242+ "MAX_TOKENS" : "length" ,
243+ "TOOL_CALLS" : "tool_calls" ,
244+ }
245+
236246 # Initialize default values
237247 content = ""
238248 index = global_index
@@ -244,23 +254,24 @@ def _convert_cohere_chunk_to_streaming_chunk(
244254 if chunk .type == "content-delta" and chunk .delta and chunk .delta .message :
245255 if chunk .delta .message and chunk .delta .message .content and chunk .delta .message .content .text is not None :
246256 content = chunk .delta .message .content .text
247- # If the previous chunk is a content-start chunk, we set start to True for the first content-delta chunk
248- if previous_original_chunks and previous_original_chunks [- 1 ].type == "content-start" :
249- start = True
250257
251258 elif chunk .type == "tool-plan-delta" and chunk .delta and chunk .delta .message :
252259 if chunk .delta .message and chunk .delta .message .tool_plan is not None :
253260 content = chunk .delta .message .tool_plan
254- # If the previous chunk is a message-start chunk, we set start to True for the first tool-plan-delta chunk
255- if previous_original_chunks and previous_original_chunks [- 1 ].type == "message-start" :
256- start = True
257261
258262 elif chunk .type == "tool-call-start" and chunk .delta and chunk .delta .message :
259263 if chunk .delta .message and chunk .delta .message .tool_calls :
260264 tool_call = chunk .delta .message .tool_calls
261265 function = tool_call .function
262266 if function is not None and function .name is not None :
263- tool_calls = [ToolCallDelta (index = global_index , id = tool_call .id , tool_name = function .name )]
267+ tool_calls = [
268+ ToolCallDelta (
269+ index = global_index ,
270+ id = tool_call .id ,
271+ tool_name = function .name ,
272+ arguments = None ,
273+ )
274+ ]
264275 start = True # This starts a tool call
265276 if tool_call .id is not None :
266277 meta ["tool_call_id" ] = tool_call .id
@@ -273,11 +284,21 @@ def _convert_cohere_chunk_to_streaming_chunk(
273284 and chunk .delta .message .tool_calls .function .arguments is not None
274285 ):
275286 arguments = chunk .delta .message .tool_calls .function .arguments
276- tool_calls = [ToolCallDelta (index = global_index , arguments = arguments )]
287+ tool_calls = [
288+ ToolCallDelta (
289+ index = global_index ,
290+ tool_name = None ,
291+ arguments = arguments ,
292+ )
293+ ]
294+
295+ elif chunk .type == "tool-call-end" :
296+ # Tool call end doesn't have content, just signals completion
297+ start = True
277298
278299 elif chunk .type == "message-end" :
279300 finish_reason_raw = getattr (chunk .delta , "finish_reason" , None )
280- finish_reason = FINISH_REASON_MAPPING .get (finish_reason_raw ) if finish_reason_raw else None
301+ finish_reason = finish_reason_mapping .get (finish_reason_raw ) if finish_reason_raw else None
281302
282303 # The Cohere API is subject to changes in how usage data is returned. We try to support both dict and objects.
283304 usage_data = getattr (chunk .delta , "usage" , None )
@@ -325,7 +346,6 @@ def _parse_streaming_response(
325346
326347 Loops through each stream object from Cohere and converts it into a StreamingChunk.
327348 """
328- original_chunks : list [StreamedChatResponseV2 ] = []
329349 chunks : list [StreamingChunk ] = []
330350 global_index = 0
331351
@@ -338,10 +358,11 @@ def _parse_streaming_response(
338358 component_info = component_info ,
339359 model = model ,
340360 global_index = global_index ,
341- previous_original_chunks = original_chunks ,
342361 )
343362
344- original_chunks .append (chunk )
363+ if not streaming_chunk :
364+ continue
365+
345366 chunks .append (streaming_chunk )
346367 streaming_callback (streaming_chunk )
347368
@@ -357,7 +378,6 @@ async def _parse_async_streaming_response(
357378 """
358379 Parses Cohere's async streaming chat response into a Haystack ChatMessage.
359380 """
360- original_chunks : list [StreamedChatResponseV2 ] = []
361381 chunks : list [StreamingChunk ] = []
362382 global_index = 0
363383
@@ -366,14 +386,11 @@ async def _parse_async_streaming_response(
366386 global_index += 1
367387
368388 streaming_chunk = _convert_cohere_chunk_to_streaming_chunk (
369- chunk = chunk ,
370- component_info = component_info ,
371- model = model ,
372- global_index = global_index ,
373- previous_original_chunks = original_chunks ,
389+ chunk = chunk , component_info = component_info , model = model , global_index = global_index
374390 )
391+ if not streaming_chunk :
392+ continue
375393
376- original_chunks .append (chunk )
377394 chunks .append (streaming_chunk )
378395 await streaming_callback (streaming_chunk )
379396
@@ -491,7 +508,7 @@ def __init__(
491508 * ,
492509 timeout : float | None = None ,
493510 max_retries : int | None = None ,
494- ):
511+ ) -> None :
495512 """
496513 Initialize the CohereChatGenerator instance.
497514
@@ -621,7 +638,10 @@ def run(
621638 """
622639
623640 # update generation kwargs by merging with the generation kwargs passed to the run method
624- generation_kwargs = {** self .generation_kwargs , ** (generation_kwargs or {})}
641+ generation_kwargs = {
642+ ** self .generation_kwargs ,
643+ ** (generation_kwargs or {}),
644+ }
625645
626646 # Handle tools
627647 tools = tools or self .tools
@@ -685,7 +705,10 @@ async def run_async(
685705 """
686706
687707 # update generation kwargs by merging with the generation kwargs passed to the run method
688- generation_kwargs = {** self .generation_kwargs , ** (generation_kwargs or {})}
708+ generation_kwargs = {
709+ ** self .generation_kwargs ,
710+ ** (generation_kwargs or {}),
711+ }
689712
690713 # Handle tools
691714 tools = tools or self .tools
0 commit comments