
Commit 5f1d9e4

fix: fix the conversion of cohere chunks to Haystack streaming chunks (#2968)
* start adding reasoning content support and doing some refactoring
* refactoring tests to be more realistic
* continue refactoring
* more refactoring
* Remove redundant tests
* fix test
* make tests more readable
* further simplify tests
* Fix when start=True in the streaming chunks
* fix finish reason mapping
* Remove reasoning for now
* make streaming and non-streaming more consistent
* fix formatting
* fix typing issues
* add same changes to async function
* formatting
* fix integration test
1 parent 0c8ea58 commit 5f1d9e4
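In short, each Cohere stream event is now converted into exactly one Haystack StreamingChunk, with the start flag and finish_reason filled in consistently. A minimal sketch of the fields involved (illustrative only; the field names follow the StreamingChunk usage visible in the diff below, and their availability depends on the Haystack version in use):

from haystack.dataclasses import StreamingChunk

chunk = StreamingChunk(
    content="Hello",     # text delta or tool-plan delta from Cohere
    index=0,             # running index across the whole stream
    start=True,          # True for the first delta after a *-start event
    finish_reason=None,  # set on "message-end" via FINISH_REASON_MAPPING
)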

4 files changed

Lines changed: 373 additions & 536 deletions

File tree

integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py

Lines changed: 59 additions & 82 deletions
@@ -38,6 +38,13 @@
 ImageFormat = Literal["image/png", "image/jpeg", "image/webp", "image/gif"]
 IMAGE_SUPPORTED_FORMATS: list[ImageFormat] = list(get_args(ImageFormat))

+FINISH_REASON_MAPPING: dict[str, FinishReason] = {
+    "COMPLETE": "stop",
+    "STOP_SEQUENCE": "stop",
+    "MAX_TOKENS": "length",
+    "TOOL_CALL": "tool_calls",
+}
+

 def _format_tool(tool: Tool) -> dict[str, Any]:
     """
@@ -51,17 +58,11 @@ def _format_tool(tool: Tool) -> dict[str, Any]:
     """
     return {
         "type": "function",
-        "function": {
-            "name": tool.name,
-            "description": tool.description,
-            "parameters": tool.parameters,
-        },
+        "function": {"name": tool.name, "description": tool.description, "parameters": tool.parameters},
     }


-def _format_message(
-    message: ChatMessage,
-) -> dict[str, Any]:
+def _format_message(message: ChatMessage) -> dict[str, Any]:
     """
     Formats a Haystack ChatMessage into Cohere's chat format.

@@ -102,17 +103,10 @@ def _format_message(
                 {
                     "id": tool_call.id,
                     "type": "function",
-                    "function": {
-                        "name": tool_call.tool_name,
-                        "arguments": json.dumps(tool_call.arguments),
-                    },
+                    "function": {"name": tool_call.tool_name, "arguments": json.dumps(tool_call.arguments)},
                 }
             )
-        return {
-            "role": "assistant",
-            "tool_calls": tool_calls,
-            "tool_plan": message.text if message.text else "",
-        }
+        return {"role": "assistant", "tool_calls": tool_calls, "tool_plan": message.text if message.text else ""}

     if message.role.value == "user":
         if not message.images and not message.text:
@@ -175,42 +169,43 @@ def _parse_response(chat_response: ChatResponse, model: str) -> ChatMessage:
     :param model: The name of the model that generated the response.
     :return: A Haystack ChatMessage containing the formatted response.
     """
+    # Extract text content from the response
+    text_content = ""
+    if chat_response.message.content:
+        for content_item in chat_response.message.content:
+            if content_item.type == "text":
+                text_content = content_item.text
+
+    # Extract tool calls if present in the response
+    tool_calls = None
     if chat_response.message.tool_calls:
         tool_calls = []
         for tc in chat_response.message.tool_calls:
             if tc.function and tc.function.name and tc.function.arguments and isinstance(tc.function.arguments, str):
                 tool_calls.append(
-                    ToolCall(
-                        id=tc.id,
-                        tool_name=tc.function.name,
-                        arguments=json.loads(tc.function.arguments),
-                    )
+                    ToolCall(id=tc.id, tool_name=tc.function.name, arguments=json.loads(tc.function.arguments))
                 )
-
-        # Create message with tool plan as text and tool calls in the format Haystack expects
-        tool_plan = chat_response.message.tool_plan or ""
-        message = ChatMessage.from_assistant(text=tool_plan, tool_calls=tool_calls)
-    elif chat_response.message.content and hasattr(chat_response.message.content[0], "text"):
-        message = ChatMessage.from_assistant(chat_response.message.content[0].text)
-    else:
-        # Handle the case where neither tool_calls nor content exists
-        logger.warning(f"Received empty response from Cohere API: {chat_response.message}")
-        message = ChatMessage.from_assistant("")
-
+        # If a tool plan is provided we use that as our text content over the default text content
+        text_content = chat_response.message.tool_plan or text_content
+
+    # Create metadata for the message
+    resolved_finish_reason = None
+    if chat_response.finish_reason:
+        resolved_finish_reason = FINISH_REASON_MAPPING.get(chat_response.finish_reason, chat_response.finish_reason)
+    base_meta = {
+        "model": model,
+        "index": 0,
+        "finish_reason": resolved_finish_reason,
+        "citations": chat_response.message.citations,
+    }
     # In V2, token usage is part of the response object, not the message
-    message._meta.update(
-        {
-            "model": model,
-            "index": 0,
-            "finish_reason": chat_response.finish_reason,
-            "citations": chat_response.message.citations,
-        }
-    )
     if chat_response.usage and chat_response.usage.billed_units:
-        message._meta["usage"] = {
+        base_meta["usage"] = {
             "prompt_tokens": chat_response.usage.billed_units.input_tokens,
             "completion_tokens": chat_response.usage.billed_units.output_tokens,
         }
+
+    message = ChatMessage.from_assistant(text=text_content, tool_calls=tool_calls, meta=base_meta)
     return message

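A condensed sketch of the new metadata assembly above (standalone; the literal values and the model name are made up, and plain dicts stand in for the Cohere response fields):

FINISH_REASON_MAPPING = {"COMPLETE": "stop", "STOP_SEQUENCE": "stop", "MAX_TOKENS": "length", "TOOL_CALL": "tool_calls"}

raw_reason = "TOOL_CALL"                                      # e.g. chat_response.finish_reason
resolved = FINISH_REASON_MAPPING.get(raw_reason, raw_reason)  # -> "tool_calls"

base_meta = {"model": "command-r", "index": 0, "finish_reason": resolved, "citations": None}
billed = {"input_tokens": 12, "output_tokens": 34}            # stand-in for usage.billed_units
base_meta["usage"] = {
    "prompt_tokens": billed["input_tokens"],
    "completion_tokens": billed["output_tokens"],
}
# base_meta is passed once as meta= to ChatMessage.from_assistant,
# instead of mutating message._meta after the message is built.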

@@ -219,6 +214,7 @@ def _convert_cohere_chunk_to_streaming_chunk(
     model: str,
     component_info: ComponentInfo | None = None,
     global_index: int = 0,
+    previous_original_chunks: list[StreamedChatResponseV2] | None = None,
 ) -> StreamingChunk:
     """
     Converts a Cohere streaming response chunk to a StreamingChunk.
@@ -237,12 +233,6 @@
     :returns:
         A StreamingChunk object representing the content of the chunk from the Cohere API.
     """
-    finish_reason_mapping: dict[str, FinishReason] = {
-        "COMPLETE": "stop",
-        "MAX_TOKENS": "length",
-        "TOOL_CALLS": "tool_calls",
-    }
-
     # Initialize default values
     content = ""
     index = global_index
@@ -254,24 +244,23 @@
     if chunk.type == "content-delta" and chunk.delta and chunk.delta.message:
         if chunk.delta.message and chunk.delta.message.content and chunk.delta.message.content.text is not None:
             content = chunk.delta.message.content.text
+            # If the previous chunk is a content-start chunk, we set start to True for the first content-delta chunk
+            if previous_original_chunks and previous_original_chunks[-1].type == "content-start":
+                start = True

     elif chunk.type == "tool-plan-delta" and chunk.delta and chunk.delta.message:
         if chunk.delta.message and chunk.delta.message.tool_plan is not None:
             content = chunk.delta.message.tool_plan
+            # If the previous chunk is a message-start chunk, we set start to True for the first tool-plan-delta chunk
+            if previous_original_chunks and previous_original_chunks[-1].type == "message-start":
+                start = True

     elif chunk.type == "tool-call-start" and chunk.delta and chunk.delta.message:
         if chunk.delta.message and chunk.delta.message.tool_calls:
             tool_call = chunk.delta.message.tool_calls
             function = tool_call.function
             if function is not None and function.name is not None:
-                tool_calls = [
-                    ToolCallDelta(
-                        index=global_index,
-                        id=tool_call.id,
-                        tool_name=function.name,
-                        arguments=None,
-                    )
-                ]
+                tool_calls = [ToolCallDelta(index=global_index, id=tool_call.id, tool_name=function.name)]
             start = True  # This starts a tool call
             if tool_call.id is not None:
                 meta["tool_call_id"] = tool_call.id
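The start-flag rule this hunk adds, isolated as a standalone sketch (the helper function and the SimpleNamespace objects are illustrative stand-ins for Cohere stream events, not code from the diff):

from types import SimpleNamespace

def is_first_delta(previous_chunks, expected_start_type):
    # start=True only when the immediately preceding raw Cohere chunk was the
    # matching "*-start" event (content-start for text, message-start for tool plans)
    return bool(previous_chunks) and previous_chunks[-1].type == expected_start_type

history = [SimpleNamespace(type="message-start"), SimpleNamespace(type="content-start")]
assert is_first_delta(history, "content-start") is True      # first content-delta
assert is_first_delta(history[:1], "message-start") is True  # first tool-plan-delta
assert is_first_delta([], "content-start") is False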
@@ -284,21 +273,11 @@
             and chunk.delta.message.tool_calls.function.arguments is not None
         ):
             arguments = chunk.delta.message.tool_calls.function.arguments
-            tool_calls = [
-                ToolCallDelta(
-                    index=global_index,
-                    tool_name=None,
-                    arguments=arguments,
-                )
-            ]
-
-    elif chunk.type == "tool-call-end":
-        # Tool call end doesn't have content, just signals completion
-        start = True
+            tool_calls = [ToolCallDelta(index=global_index, arguments=arguments)]

     elif chunk.type == "message-end":
         finish_reason_raw = getattr(chunk.delta, "finish_reason", None)
-        finish_reason = finish_reason_mapping.get(finish_reason_raw) if finish_reason_raw else None
+        finish_reason = FINISH_REASON_MAPPING.get(finish_reason_raw) if finish_reason_raw else None

         # The Cohere API is subject to changes in how usage data is returned. We try to support both dict and objects.
         usage_data = getattr(chunk.delta, "usage", None)
@@ -346,6 +325,7 @@ def _parse_streaming_response(

     Loops through each stream object from Cohere and converts it into a StreamingChunk.
     """
+    original_chunks: list[StreamedChatResponseV2] = []
     chunks: list[StreamingChunk] = []
     global_index = 0

@@ -358,11 +338,10 @@
             component_info=component_info,
             model=model,
             global_index=global_index,
+            previous_original_chunks=original_chunks,
         )

-        if not streaming_chunk:
-            continue
-
+        original_chunks.append(chunk)
         chunks.append(streaming_chunk)
         streaming_callback(streaming_chunk)

@@ -378,6 +357,7 @@ async def _parse_async_streaming_response(
     """
     Parses Cohere's async streaming chat response into a Haystack ChatMessage.
     """
+    original_chunks: list[StreamedChatResponseV2] = []
     chunks: list[StreamingChunk] = []
     global_index = 0

@@ -386,11 +366,14 @@
         global_index += 1

         streaming_chunk = _convert_cohere_chunk_to_streaming_chunk(
-            chunk=chunk, component_info=component_info, model=model, global_index=global_index
+            chunk=chunk,
+            component_info=component_info,
+            model=model,
+            global_index=global_index,
+            previous_original_chunks=original_chunks,
         )
-        if not streaming_chunk:
-            continue

+        original_chunks.append(chunk)
         chunks.append(streaming_chunk)
         await streaming_callback(streaming_chunk)

@@ -638,10 +621,7 @@ def run(
         """

         # update generation kwargs by merging with the generation kwargs passed to the run method
-        generation_kwargs = {
-            **self.generation_kwargs,
-            **(generation_kwargs or {}),
-        }
+        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

         # Handle tools
         tools = tools or self.tools
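The condensed merge keeps the same precedence as before; a tiny standalone illustration (the keys are examples only):

defaults = {"temperature": 0.3, "max_tokens": 256}  # self.generation_kwargs
per_call = {"temperature": 0.9}                     # generation_kwargs passed to run()
merged = {**defaults, **per_call}
assert merged == {"temperature": 0.9, "max_tokens": 256}  # per-call values win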
@@ -705,10 +685,7 @@ async def run_async(
         """

         # update generation kwargs by merging with the generation kwargs passed to the run method
-        generation_kwargs = {
-            **self.generation_kwargs,
-            **(generation_kwargs or {}),
-        }
+        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

         # Handle tools
         tools = tools or self.tools
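For context, a hedged usage sketch of the component this file defines, streaming to a simple print callback (assumes a valid COHERE_API_KEY in the environment and that the model name shown is still available):

from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.cohere import CohereChatGenerator

def print_chunk(chunk):
    # each StreamingChunk produced by the conversion above lands here
    print(chunk.content, end="", flush=True)

generator = CohereChatGenerator(model="command-r-08-2024", streaming_callback=print_chunk)
result = generator.run(messages=[ChatMessage.from_user("Summarize what a tool call is.")])
print()
print(result["replies"][0].meta.get("finish_reason"))  # e.g. "stop"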

integrations/cohere/tests/test_chat_generator.py

Lines changed: 1 addition & 0 deletions
@@ -462,6 +462,7 @@ def test_run_image(self):
         mock_response = MagicMock()
         mock_response.message.content = [MagicMock()]
         mock_response.message.content[0].text = "This is a test image response"
+        mock_response.message.content[0].type = "text"
         mock_response.message.tool_calls = None
         mock_response.finish_reason = "COMPLETE"
         mock_response.usage = None
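The added line is needed because the reworked _parse_response now checks each content item's type before reading its text; a standalone illustration of that check (SimpleNamespace is just a stand-in for the mocked content item):

from types import SimpleNamespace

item = SimpleNamespace(type="text", text="This is a test image response")
text_content = item.text if item.type == "text" else ""
assert text_content == "This is a test image response"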
