Skip to content

Commit b866f08

Browse files
julian-risch and claude
committed
chore: enable ANN ruff ruleset for cohere integration
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent fc82ec4 commit b866f08

11 files changed

Lines changed: 545 additions & 387 deletions

File tree

integrations/cohere/CHANGELOG.md

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,5 @@
11
# Changelog
22

3-
## [integrations/cohere-v8.0.1] - 2026-03-17
4-
5-
### 🐛 Bug Fixes
6-
7-
- Fix the conversion of cohere chunks to Haystack streaming chunks (#2968)
8-
9-
103
## [integrations/cohere-v8.0.0] - 2026-03-11
114

125
### 🚀 Features

integrations/cohere/pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ line-length = 120
8484
[tool.ruff.lint]
8585
select = [
8686
"A",
87+
"ANN",
8788
"ARG",
8889
"B",
8990
"C",
@@ -111,6 +112,7 @@ select = [
111112
ignore = [
112113
# Allow non-abstract empty methods in abstract base classes
113114
"B027",
115+
"ANN401", # Allow Any - used legitimately for dynamic types and SDK boundaries
114116
# Ignore checks for possible passwords
115117
"S105",
116118
"S106",
@@ -134,7 +136,7 @@ ban-relative-imports = "parents"
134136

135137
[tool.ruff.lint.per-file-ignores]
136138
# Tests can use magic values, assertions, and relative imports
137-
"tests/**/*" = ["PLR2004", "S101", "TID252"]
139+
"tests/**/*" = ["PLR2004", "S101", "TID252", "ANN"]
138140

139141
[tool.coverage.run]
140142
source = ["haystack_integrations"]

integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __init__(
4949
meta_fields_to_embed: list[str] | None = None,
5050
embedding_separator: str = "\n",
5151
embedding_type: EmbeddingTypes | None = None,
52-
):
52+
) -> None:
5353
"""
5454
:param api_key: the Cohere API key.
5555
:param model: the name of the model to use. Supported Models are:

integrations/cohere/src/haystack_integrations/components/embedders/cohere/embedding_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class EmbeddingTypes(Enum):
2121
BINARY = "binary"
2222
UBINARY = "ubinary"
2323

24-
def __str__(self):
24+
def __str__(self) -> str:
2525
return self.value
2626

2727
@staticmethod

integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def __init__(
4141
truncate: str = "END",
4242
timeout: float = 120.0,
4343
embedding_type: EmbeddingTypes | None = None,
44-
):
44+
) -> None:
4545
"""
4646
:param api_key: the Cohere API key.
4747
:param model: the name of the model to use. Supported Models are:

integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py

Lines changed: 83 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,6 @@
3838
ImageFormat = Literal["image/png", "image/jpeg", "image/webp", "image/gif"]
3939
IMAGE_SUPPORTED_FORMATS: list[ImageFormat] = list(get_args(ImageFormat))
4040

41-
FINISH_REASON_MAPPING: dict[str, FinishReason] = {
42-
"COMPLETE": "stop",
43-
"STOP_SEQUENCE": "stop",
44-
"MAX_TOKENS": "length",
45-
"TOOL_CALL": "tool_calls",
46-
}
47-
4841

4942
def _format_tool(tool: Tool) -> dict[str, Any]:
5043
"""
@@ -58,11 +51,17 @@ def _format_tool(tool: Tool) -> dict[str, Any]:
5851
"""
5952
return {
6053
"type": "function",
61-
"function": {"name": tool.name, "description": tool.description, "parameters": tool.parameters},
54+
"function": {
55+
"name": tool.name,
56+
"description": tool.description,
57+
"parameters": tool.parameters,
58+
},
6259
}
6360

6461

65-
def _format_message(message: ChatMessage) -> dict[str, Any]:
62+
def _format_message(
63+
message: ChatMessage,
64+
) -> dict[str, Any]:
6665
"""
6766
Formats a Haystack ChatMessage into Cohere's chat format.
6867
@@ -103,10 +102,17 @@ def _format_message(message: ChatMessage) -> dict[str, Any]:
103102
{
104103
"id": tool_call.id,
105104
"type": "function",
106-
"function": {"name": tool_call.tool_name, "arguments": json.dumps(tool_call.arguments)},
105+
"function": {
106+
"name": tool_call.tool_name,
107+
"arguments": json.dumps(tool_call.arguments),
108+
},
107109
}
108110
)
109-
return {"role": "assistant", "tool_calls": tool_calls, "tool_plan": message.text if message.text else ""}
111+
return {
112+
"role": "assistant",
113+
"tool_calls": tool_calls,
114+
"tool_plan": message.text if message.text else "",
115+
}
110116

111117
if message.role.value == "user":
112118
if not message.images and not message.text:
@@ -169,43 +175,42 @@ def _parse_response(chat_response: ChatResponse, model: str) -> ChatMessage:
169175
:param model: The name of the model that generated the response.
170176
:return: A Haystack ChatMessage containing the formatted response.
171177
"""
172-
# Extract text content from the response
173-
text_content = ""
174-
if chat_response.message.content:
175-
for content_item in chat_response.message.content:
176-
if content_item.type == "text":
177-
text_content = content_item.text
178-
179-
# Extract tool calls if present in the response
180-
tool_calls = None
181178
if chat_response.message.tool_calls:
182179
tool_calls = []
183180
for tc in chat_response.message.tool_calls:
184181
if tc.function and tc.function.name and tc.function.arguments and isinstance(tc.function.arguments, str):
185182
tool_calls.append(
186-
ToolCall(id=tc.id, tool_name=tc.function.name, arguments=json.loads(tc.function.arguments))
183+
ToolCall(
184+
id=tc.id,
185+
tool_name=tc.function.name,
186+
arguments=json.loads(tc.function.arguments),
187+
)
187188
)
188-
# If a tool plan is provided we use that as our text content over the default text content
189-
text_content = chat_response.message.tool_plan or text_content
190-
191-
# Create metadata for the message
192-
resolved_finish_reason = None
193-
if chat_response.finish_reason:
194-
resolved_finish_reason = FINISH_REASON_MAPPING.get(chat_response.finish_reason, chat_response.finish_reason)
195-
base_meta = {
196-
"model": model,
197-
"index": 0,
198-
"finish_reason": resolved_finish_reason,
199-
"citations": chat_response.message.citations,
200-
}
189+
190+
# Create message with tool plan as text and tool calls in the format Haystack expects
191+
tool_plan = chat_response.message.tool_plan or ""
192+
message = ChatMessage.from_assistant(text=tool_plan, tool_calls=tool_calls)
193+
elif chat_response.message.content and hasattr(chat_response.message.content[0], "text"):
194+
message = ChatMessage.from_assistant(chat_response.message.content[0].text)
195+
else:
196+
# Handle the case where neither tool_calls nor content exists
197+
logger.warning(f"Received empty response from Cohere API: {chat_response.message}")
198+
message = ChatMessage.from_assistant("")
199+
201200
# In V2, token usage is part of the response object, not the message
201+
message._meta.update(
202+
{
203+
"model": model,
204+
"index": 0,
205+
"finish_reason": chat_response.finish_reason,
206+
"citations": chat_response.message.citations,
207+
}
208+
)
202209
if chat_response.usage and chat_response.usage.billed_units:
203-
base_meta["usage"] = {
210+
message._meta["usage"] = {
204211
"prompt_tokens": chat_response.usage.billed_units.input_tokens,
205212
"completion_tokens": chat_response.usage.billed_units.output_tokens,
206213
}
207-
208-
message = ChatMessage.from_assistant(text=text_content, tool_calls=tool_calls, meta=base_meta)
209214
return message
210215

211216

@@ -214,7 +219,6 @@ def _convert_cohere_chunk_to_streaming_chunk(
214219
model: str,
215220
component_info: ComponentInfo | None = None,
216221
global_index: int = 0,
217-
previous_original_chunks: list[StreamedChatResponseV2] | None = None,
218222
) -> StreamingChunk:
219223
"""
220224
Converts a Cohere streaming response chunk to a StreamingChunk.
@@ -233,6 +237,12 @@ def _convert_cohere_chunk_to_streaming_chunk(
233237
:returns:
234238
A StreamingChunk object representing the content of the chunk from the Cohere API.
235239
"""
240+
finish_reason_mapping: dict[str, FinishReason] = {
241+
"COMPLETE": "stop",
242+
"MAX_TOKENS": "length",
243+
"TOOL_CALLS": "tool_calls",
244+
}
245+
236246
# Initialize default values
237247
content = ""
238248
index = global_index
@@ -244,23 +254,24 @@ def _convert_cohere_chunk_to_streaming_chunk(
244254
if chunk.type == "content-delta" and chunk.delta and chunk.delta.message:
245255
if chunk.delta.message and chunk.delta.message.content and chunk.delta.message.content.text is not None:
246256
content = chunk.delta.message.content.text
247-
# If the previous chunk is a content-start chunk, we set start to True for the first content-delta chunk
248-
if previous_original_chunks and previous_original_chunks[-1].type == "content-start":
249-
start = True
250257

251258
elif chunk.type == "tool-plan-delta" and chunk.delta and chunk.delta.message:
252259
if chunk.delta.message and chunk.delta.message.tool_plan is not None:
253260
content = chunk.delta.message.tool_plan
254-
# If the previous chunk is a message-start chunk, we set start to True for the first tool-plan-delta chunk
255-
if previous_original_chunks and previous_original_chunks[-1].type == "message-start":
256-
start = True
257261

258262
elif chunk.type == "tool-call-start" and chunk.delta and chunk.delta.message:
259263
if chunk.delta.message and chunk.delta.message.tool_calls:
260264
tool_call = chunk.delta.message.tool_calls
261265
function = tool_call.function
262266
if function is not None and function.name is not None:
263-
tool_calls = [ToolCallDelta(index=global_index, id=tool_call.id, tool_name=function.name)]
267+
tool_calls = [
268+
ToolCallDelta(
269+
index=global_index,
270+
id=tool_call.id,
271+
tool_name=function.name,
272+
arguments=None,
273+
)
274+
]
264275
start = True # This starts a tool call
265276
if tool_call.id is not None:
266277
meta["tool_call_id"] = tool_call.id
@@ -273,11 +284,21 @@ def _convert_cohere_chunk_to_streaming_chunk(
273284
and chunk.delta.message.tool_calls.function.arguments is not None
274285
):
275286
arguments = chunk.delta.message.tool_calls.function.arguments
276-
tool_calls = [ToolCallDelta(index=global_index, arguments=arguments)]
287+
tool_calls = [
288+
ToolCallDelta(
289+
index=global_index,
290+
tool_name=None,
291+
arguments=arguments,
292+
)
293+
]
294+
295+
elif chunk.type == "tool-call-end":
296+
# Tool call end doesn't have content, just signals completion
297+
start = True
277298

278299
elif chunk.type == "message-end":
279300
finish_reason_raw = getattr(chunk.delta, "finish_reason", None)
280-
finish_reason = FINISH_REASON_MAPPING.get(finish_reason_raw) if finish_reason_raw else None
301+
finish_reason = finish_reason_mapping.get(finish_reason_raw) if finish_reason_raw else None
281302

282303
# The Cohere API is subject to changes in how usage data is returned. We try to support both dict and objects.
283304
usage_data = getattr(chunk.delta, "usage", None)
@@ -325,7 +346,6 @@ def _parse_streaming_response(
325346
326347
Loops through each stream object from Cohere and converts it into a StreamingChunk.
327348
"""
328-
original_chunks: list[StreamedChatResponseV2] = []
329349
chunks: list[StreamingChunk] = []
330350
global_index = 0
331351

@@ -338,10 +358,11 @@ def _parse_streaming_response(
338358
component_info=component_info,
339359
model=model,
340360
global_index=global_index,
341-
previous_original_chunks=original_chunks,
342361
)
343362

344-
original_chunks.append(chunk)
363+
if not streaming_chunk:
364+
continue
365+
345366
chunks.append(streaming_chunk)
346367
streaming_callback(streaming_chunk)
347368

@@ -357,7 +378,6 @@ async def _parse_async_streaming_response(
357378
"""
358379
Parses Cohere's async streaming chat response into a Haystack ChatMessage.
359380
"""
360-
original_chunks: list[StreamedChatResponseV2] = []
361381
chunks: list[StreamingChunk] = []
362382
global_index = 0
363383

@@ -366,14 +386,11 @@ async def _parse_async_streaming_response(
366386
global_index += 1
367387

368388
streaming_chunk = _convert_cohere_chunk_to_streaming_chunk(
369-
chunk=chunk,
370-
component_info=component_info,
371-
model=model,
372-
global_index=global_index,
373-
previous_original_chunks=original_chunks,
389+
chunk=chunk, component_info=component_info, model=model, global_index=global_index
374390
)
391+
if not streaming_chunk:
392+
continue
375393

376-
original_chunks.append(chunk)
377394
chunks.append(streaming_chunk)
378395
await streaming_callback(streaming_chunk)
379396

@@ -491,7 +508,7 @@ def __init__(
491508
*,
492509
timeout: float | None = None,
493510
max_retries: int | None = None,
494-
):
511+
) -> None:
495512
"""
496513
Initialize the CohereChatGenerator instance.
497514
@@ -621,7 +638,10 @@ def run(
621638
"""
622639

623640
# update generation kwargs by merging with the generation kwargs passed to the run method
624-
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
641+
generation_kwargs = {
642+
**self.generation_kwargs,
643+
**(generation_kwargs or {}),
644+
}
625645

626646
# Handle tools
627647
tools = tools or self.tools
@@ -685,7 +705,10 @@ async def run_async(
685705
"""
686706

687707
# update generation kwargs by merging with the generation kwargs passed to the run method
688-
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
708+
generation_kwargs = {
709+
**self.generation_kwargs,
710+
**(generation_kwargs or {}),
711+
}
689712

690713
# Handle tools
691714
tools = tools or self.tools

integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __init__(
3737
streaming_callback: Callable | None = None,
3838
api_base_url: str | None = None,
3939
**kwargs: Any,
40-
):
40+
) -> None:
4141
"""
4242
Instantiates a `CohereGenerator` component.
4343

integrations/cohere/src/haystack_integrations/components/rankers/cohere/ranker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def __init__(
4141
meta_fields_to_embed: list[str] | None = None,
4242
meta_data_separator: str = "\n",
4343
max_tokens_per_doc: int = 4096,
44-
):
44+
) -> None:
4545
"""
4646
Creates an instance of the 'CohereRanker'.
4747

integrations/cohere/tests/test_chat_generator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,6 @@ def test_run_image(self):
462462
mock_response = MagicMock()
463463
mock_response.message.content = [MagicMock()]
464464
mock_response.message.content[0].text = "This is a test image response"
465-
mock_response.message.content[0].type = "text"
466465
mock_response.message.tool_calls = None
467466
mock_response.finish_reason = "COMPLETE"
468467
mock_response.usage = None

0 commit comments

Comments (0)