Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions integrations/cohere/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,22 @@ name = "cohere-haystack"
dynamic = ["version"]
description = ''
readme = "README.md"
requires-python = ">=3.9"
requires-python = ">=3.10"
license = "Apache-2.0"
keywords = []
authors = [{ name = "deepset GmbH", email = "info@deepset.ai" }]
classifiers = [
"License :: OSI Approved :: Apache Software License",
"Development Status :: 4 - Beta",
"Programming Language :: Python",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = ["haystack-ai>=2.19.0", "cohere>=5.17.0"]
dependencies = ["haystack-ai>=2.22.0", "cohere>=5.17.0"]

[project.urls]
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/cohere#readme"
Expand Down Expand Up @@ -79,7 +78,6 @@ check_untyped_defs = true
disallow_incomplete_defs = true

[tool.ruff]
target-version = "py39"
line-length = 120

[tool.ruff.lint]
Expand Down Expand Up @@ -126,10 +124,6 @@ ignore = [
"B008",
"S101",
]
unfixable = [
# Don't touch unused imports
"F401",
]

[tool.ruff.lint.isort]
known-first-party = ["haystack_integrations"]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional, Union
from typing import Any

from haystack import Document, component, default_from_dict, default_to_dict
from haystack.utils import Secret, deserialize_secrets_inplace
Expand Down Expand Up @@ -45,9 +45,9 @@ def __init__(
timeout: float = 120.0,
batch_size: int = 32,
progress_bar: bool = True,
meta_fields_to_embed: Optional[list[str]] = None,
meta_fields_to_embed: list[str] | None = None,
embedding_separator: str = "\n",
embedding_type: Optional[EmbeddingTypes] = None,
embedding_type: EmbeddingTypes | None = None,
):
"""
:param api_key: the Cohere API key.
Expand Down Expand Up @@ -167,7 +167,7 @@ def _prepare_texts_to_embed(self, documents: list[Document]) -> list[str]:
return texts_to_embed

@component.output_types(documents=list[Document], meta=dict[str, Any])
def run(self, documents: list[Document]) -> dict[str, Union[list[Document], dict[str, Any]]]:
def run(self, documents: list[Document]) -> dict[str, list[Document] | dict[str, Any]]:
"""Embed a list of `Documents`.

:param documents: documents to embed.
Expand Down Expand Up @@ -195,13 +195,13 @@ def run(self, documents: list[Document]) -> dict[str, Union[list[Document], dict
self.embedding_type,
)

for doc, embeddings in zip(documents, all_embeddings):
for doc, embeddings in zip(documents, all_embeddings, strict=True):
doc.embedding = embeddings

return {"documents": documents, "meta": metadata}

@component.output_types(documents=list[Document], meta=dict[str, Any])
async def run_async(self, documents: list[Document]) -> dict[str, Union[list[Document], dict[str, Any]]]:
async def run_async(self, documents: list[Document]) -> dict[str, list[Document] | dict[str, Any]]:
"""
Embed a list of `Documents` asynchronously.

Expand All @@ -228,7 +228,7 @@ async def run_async(self, documents: list[Document]) -> dict[str, Union[list[Doc
embedding_type=self.embedding_type,
)

for doc, embeddings in zip(documents, all_embeddings):
for doc, embeddings in zip(documents, all_embeddings, strict=True):
doc.embedding = embeddings

return {"documents": documents, "meta": metadata}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import replace
from typing import Any, Optional
from typing import Any

from haystack import Document, component, default_from_dict, default_to_dict, logging
from haystack.components.converters.image.image_utils import (
Expand Down Expand Up @@ -63,13 +63,13 @@ def __init__(
self,
*,
file_path_meta_field: str = "file_path",
root_path: Optional[str] = None,
image_size: Optional[tuple[int, int]] = None,
root_path: str | None = None,
image_size: tuple[int, int] | None = None,
api_key: Secret = Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]),
model: str = "embed-v4.0",
api_base_url: str = "https://api.cohere.com",
timeout: float = 120.0,
embedding_dimension: Optional[int] = None,
embedding_dimension: int | None = None,
embedding_type: EmbeddingTypes = EmbeddingTypes.FLOAT,
progress_bar: bool = True,
) -> None:
Expand Down Expand Up @@ -205,7 +205,7 @@ def _extract_images_to_embed(self, documents: list[Document]) -> list[str]:
)
raise ValueError(msg)

images_to_embed: list[Optional[str]] = [None] * len(documents)
images_to_embed: list[str | None] = [None] * len(documents)
pdf_page_infos: list[_PDFPageInfo] = []

for doc_idx, image_source_info in enumerate(images_source_info):
Expand Down Expand Up @@ -259,7 +259,9 @@ def run(self, documents: list[Document]) -> dict[str, list[Document]]:
embeddings = []

# The Cohere API only supports passing one image at a time
for doc, image in tqdm(zip(documents, images_to_embed), desc="Embedding images", disable=not self.progress_bar):
for doc, image in tqdm(
zip(documents, images_to_embed, strict=True), desc="Embedding images", disable=not self.progress_bar
):
try:
response = self._client.embed(
model=self.model,
Expand All @@ -276,7 +278,7 @@ def run(self, documents: list[Document]) -> dict[str, list[Document]]:
embeddings.append(embedding)

docs_with_embeddings = []
for doc, emb in zip(documents, embeddings):
for doc, emb in zip(documents, embeddings, strict=True):
# we store this information for later inspection
new_meta = {
**doc.meta,
Expand Down Expand Up @@ -305,7 +307,9 @@ async def run_async(self, documents: list[Document]) -> dict[str, list[Document]
embeddings = []

# The Cohere API only supports passing one image at a time
for doc, image in tqdm(zip(documents, images_to_embed), desc="Embedding images", disable=not self.progress_bar):
for doc, image in tqdm(
zip(documents, images_to_embed, strict=True), desc="Embedding images", disable=not self.progress_bar
):
try:
response = await self._async_client.embed(
model=self.model,
Expand All @@ -322,7 +326,7 @@ async def run_async(self, documents: list[Document]) -> dict[str, list[Document]
embeddings.append(embedding)

docs_with_embeddings = []
for doc, emb in zip(documents, embeddings):
for doc, emb in zip(documents, embeddings, strict=True):
# we store this information for later inspection
new_meta = {
**doc.meta,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional, Union
from typing import Any

from haystack import component, default_from_dict, default_to_dict
from haystack.utils import Secret, deserialize_secrets_inplace
Expand Down Expand Up @@ -40,7 +40,7 @@ def __init__(
api_base_url: str = "https://api.cohere.com",
truncate: str = "END",
timeout: float = 120.0,
embedding_type: Optional[EmbeddingTypes] = None,
embedding_type: EmbeddingTypes | None = None,
):
"""
:param api_key: the Cohere API key.
Expand Down Expand Up @@ -134,7 +134,7 @@ def from_dict(cls, data: dict[str, Any]) -> "CohereTextEmbedder":
return default_from_dict(cls, data)

@component.output_types(embedding=list[float], meta=dict[str, Any])
def run(self, text: str) -> dict[str, Union[list[float], dict[str, Any]]]:
def run(self, text: str) -> dict[str, list[float] | dict[str, Any]]:
"""
Embed text.

Expand All @@ -161,7 +161,7 @@ def run(self, text: str) -> dict[str, Union[list[float], dict[str, Any]]]:
return {"embedding": embedding[0], "meta": metadata}

@component.output_types(embedding=list[float], meta=dict[str, Any])
async def run_async(self, text: str) -> dict[str, Union[list[float], dict[str, Any]]]:
async def run_async(self, text: str) -> dict[str, list[float] | dict[str, Any]]:
"""
Asynchronously embed text.

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional
from typing import Any

from tqdm import tqdm

Expand All @@ -16,7 +16,7 @@ async def get_async_response(
model_name: str,
input_type: str,
truncate: str,
embedding_type: Optional[EmbeddingTypes] = None,
embedding_type: EmbeddingTypes | None = None,
) -> tuple[list[list[float]], dict[str, Any]]:
"""Embeds a list of texts asynchronously using the Cohere API.

Expand Down Expand Up @@ -64,7 +64,7 @@ def get_response(
truncate: str,
batch_size: int = 32,
progress_bar: bool = False,
embedding_type: Optional[EmbeddingTypes] = None,
embedding_type: EmbeddingTypes | None = None,
) -> tuple[list[list[float]], dict[str, Any]]:
"""Embeds a list of texts using the Cohere API.

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from collections.abc import AsyncIterator, Iterator
from typing import Any, Literal, Optional, Union, get_args
from typing import Any, Literal, get_args

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message
Expand Down Expand Up @@ -77,7 +77,7 @@ def _format_tool(tool: Tool) -> dict[str, Any]:

def _format_message(
message: ChatMessage,
) -> Union[UserChatMessageV2, AssistantChatMessageV2, SystemChatMessageV2, ToolChatMessageV2]:
) -> UserChatMessageV2 | AssistantChatMessageV2 | SystemChatMessageV2 | ToolChatMessageV2:
"""
Formats a Haystack ChatMessage into Cohere's chat format.

Expand Down Expand Up @@ -147,7 +147,7 @@ def _format_message(
raise ValueError(msg)

# Build multimodal content following Cohere's API specification
content_parts: list[Union[CohereTextContent, ImageUrlContent]] = []
content_parts: list[CohereTextContent | ImageUrlContent] = []
for part in message._content:
if isinstance(part, TextContent) and part.text:
text_content = CohereTextContent(text=part.text)
Expand Down Expand Up @@ -234,7 +234,7 @@ def _parse_response(chat_response: ChatResponse, model: str) -> ChatMessage:
def _convert_cohere_chunk_to_streaming_chunk(
chunk: StreamedChatResponseV2,
model: str,
component_info: Optional[ComponentInfo] = None,
component_info: ComponentInfo | None = None,
global_index: int = 0,
) -> StreamingChunk:
"""
Expand Down Expand Up @@ -518,10 +518,10 @@ def __init__(
self,
api_key: Secret = Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]),
model: str = "command-a-03-2025",
streaming_callback: Optional[StreamingCallbackT] = None,
api_base_url: Optional[str] = None,
generation_kwargs: Optional[dict[str, Any]] = None,
tools: Optional[ToolsType] = None,
streaming_callback: StreamingCallbackT | None = None,
api_base_url: str | None = None,
generation_kwargs: dict[str, Any] | None = None,
tools: ToolsType | None = None,
**kwargs: Any,
):
"""
Expand Down Expand Up @@ -618,9 +618,9 @@ def from_dict(cls, data: dict[str, Any]) -> "CohereChatGenerator":
def run(
self,
messages: list[ChatMessage],
generation_kwargs: Optional[dict[str, Any]] = None,
tools: Optional[ToolsType] = None,
streaming_callback: Optional[StreamingCallbackT] = None,
generation_kwargs: dict[str, Any] | None = None,
tools: ToolsType | None = None,
streaming_callback: StreamingCallbackT | None = None,
) -> dict[str, list[ChatMessage]]:
"""
Invoke the chat endpoint based on the provided messages and generation parameters.
Expand Down Expand Up @@ -685,9 +685,9 @@ def run(
async def run_async(
self,
messages: list[ChatMessage],
generation_kwargs: Optional[dict[str, Any]] = None,
tools: Optional[ToolsType] = None,
streaming_callback: Optional[StreamingCallbackT] = None,
generation_kwargs: dict[str, Any] | None = None,
tools: ToolsType | None = None,
streaming_callback: StreamingCallbackT | None = None,
) -> dict[str, list[ChatMessage]]:
"""
Asynchronously invoke the chat endpoint based on the provided messages and generation parameters.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Optional, Union
from collections.abc import Callable
from typing import Any

from haystack import component, logging
from haystack.dataclasses import ChatMessage
Expand Down Expand Up @@ -33,8 +34,8 @@ def __init__(
self,
api_key: Secret = Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]),
model: str = "command-a-03-2025",
streaming_callback: Optional[Callable] = None,
api_base_url: Optional[str] = None,
streaming_callback: Callable | None = None,
api_base_url: str | None = None,
**kwargs: Any,
):
"""
Expand All @@ -57,7 +58,7 @@ def __init__(
def run( # type: ignore[override] # due to incompatible signature with ChatGenerator
self,
prompt: str,
) -> dict[str, Union[list[str], list[dict[str, Any]]]]:
) -> dict[str, list[str] | list[dict[str, Any]]]:
"""
Queries the LLM with the prompts to produce replies.

Expand All @@ -80,7 +81,7 @@ def run( # type: ignore[override] # due to incompatible signature with ChatGene
async def run_async( # type: ignore[override] # due to incompatible signature with ChatGenerator
self,
prompt: str,
) -> dict[str, Union[list[str], list[dict[str, Any]]]]:
) -> dict[str, list[str] | list[dict[str, Any]]]:
"""
Queries the LLM asynchronously with the prompts to produce replies.
:param prompt: the prompt to be sent to the generative model.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Optional
from typing import Any

from haystack import Document, component, default_from_dict, default_to_dict, logging
from haystack.utils import Secret, deserialize_secrets_inplace
Expand Down Expand Up @@ -37,7 +37,7 @@ def __init__(
top_k: int = 10,
api_key: Secret = Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]),
api_base_url: str = "https://api.cohere.com",
meta_fields_to_embed: Optional[list[str]] = None,
meta_fields_to_embed: list[str] | None = None,
meta_data_separator: str = "\n",
max_tokens_per_doc: int = 4096,
):
Expand Down Expand Up @@ -120,7 +120,7 @@ def _prepare_cohere_input_docs(self, documents: list[Document]) -> list[str]:
return concatenated_input_list

@component.output_types(documents=list[Document])
def run(self, query: str, documents: list[Document], top_k: Optional[int] = None) -> dict[str, list[Document]]:
def run(self, query: str, documents: list[Document], top_k: int | None = None) -> dict[str, list[Document]]:
"""
Use the Cohere Reranker to re-rank the list of documents based on the query.

Expand Down Expand Up @@ -160,7 +160,7 @@ def run(self, query: str, documents: list[Document], top_k: Optional[int] = None
indices = [output.index for output in response.results]
scores = [output.relevance_score for output in response.results]
sorted_docs = []
for idx, score in zip(indices, scores):
for idx, score in zip(indices, scores, strict=True):
doc = documents[idx]
doc.score = score
sorted_docs.append(documents[idx])
Expand Down
Loading