Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
)
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
from httpx import AsyncClient as AsyncHTTPXClient
from httpx import AsyncHTTPTransport, HTTPTransport
from httpx import Client as HTTPXClient

from cohere import AsyncClientV2, ChatResponse, ClientV2, StreamedChatResponseV2

Expand Down Expand Up @@ -502,6 +505,9 @@ def __init__(
api_base_url: str | None = None,
generation_kwargs: dict[str, Any] | None = None,
tools: ToolsType | None = None,
*,
timeout: float | None = None,
max_retries: int | None = None,
**kwargs: Any,
):
"""
Expand All @@ -526,6 +532,13 @@ def __init__(
mean less random generations.
:param tools: A list of Tool and/or Toolset objects, or a single Toolset that the model can use.
Each tool should have a unique name.
:param timeout:
Timeout for Cohere client calls. If not set, it defaults to the default set by the Cohere client.
:param max_retries:
Maximum number of retries to attempt for failed requests. If not set, it defaults to the default set by
the Cohere client.
:param kwargs:
Additional generation parameters. These are merged into `generation_kwargs` for backward compatibility.

"""
_check_duplicate_tool_names(flatten_tools_or_toolsets(tools))
Expand All @@ -534,23 +547,32 @@ def __init__(
api_base_url = "https://api.cohere.com"
if generation_kwargs is None:
generation_kwargs = {}
if kwargs:
generation_kwargs = {**generation_kwargs, **kwargs}
self.api_key = api_key
self.model = model
self.streaming_callback = streaming_callback
self.api_base_url = api_base_url
self.generation_kwargs = generation_kwargs
self.tools = tools
self.model_parameters = kwargs
self.client = ClientV2(
api_key=self.api_key.resolve_value(),
base_url=self.api_base_url,
client_name="haystack",
)
self.async_client = AsyncClientV2(
api_key=self.api_key.resolve_value(),
base_url=self.api_base_url,
client_name="haystack",
)
self.timeout = timeout
self.max_retries = max_retries

client_kwargs: dict[str, Any] = {
"api_key": self.api_key.resolve_value(),
"base_url": self.api_base_url,
"client_name": "haystack",
}
if timeout is not None:
client_kwargs["timeout"] = timeout
if max_retries is not None:
sync_httpx_client = HTTPXClient(transport=HTTPTransport(retries=max_retries))
async_httpx_client = AsyncHTTPXClient(transport=AsyncHTTPTransport(retries=max_retries))
self.client = ClientV2(**client_kwargs, httpx_client=sync_httpx_client)
self.async_client = AsyncClientV2(**client_kwargs, httpx_client=async_httpx_client)
else:
self.client = ClientV2(**client_kwargs)
self.async_client = AsyncClientV2(**client_kwargs)

def _get_telemetry_data(self) -> dict[str, Any]:
"""
Expand All @@ -574,6 +596,8 @@ def to_dict(self) -> dict[str, Any]:
api_key=self.api_key.to_dict(),
generation_kwargs=self.generation_kwargs,
tools=serialize_tools_or_toolset(self.tools),
timeout=self.timeout,
max_retries=self.max_retries,
)

@classmethod
Expand Down
18 changes: 18 additions & 0 deletions integrations/cohere/tests/test_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ def test_init_default(self, monkeypatch):
assert component.streaming_callback is None
assert component.api_base_url == "https://api.cohere.com"
assert not component.generation_kwargs
assert component.timeout is None
assert component.max_retries is None

def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("COHERE_API_KEY", raising=False)
Expand All @@ -174,6 +176,8 @@ def test_init_with_parameters(self):
"max_tokens": 10,
"some_test_param": "test-params",
},
timeout=30.0,
max_retries=5,
)
assert component.api_key == Secret.from_token("test-api-key")
assert component.model == "command-nightly"
Expand All @@ -183,6 +187,8 @@ def test_init_with_parameters(self):
"max_tokens": 10,
"some_test_param": "test-params",
}
assert component.timeout == 30.0
assert component.max_retries == 5

def test_to_dict_default(self, monkeypatch):
monkeypatch.setenv("COHERE_API_KEY", "test-api-key")
Expand All @@ -201,6 +207,8 @@ def test_to_dict_default(self, monkeypatch):
"api_base_url": "https://api.cohere.com",
"generation_kwargs": {},
"tools": None,
"timeout": None,
"max_retries": None,
},
}

Expand Down Expand Up @@ -234,6 +242,8 @@ def test_to_dict_with_parameters(self, monkeypatch):
"some_test_param": "test-params",
},
"tools": None,
"timeout": None,
"max_retries": None,
},
}

Expand All @@ -255,6 +265,8 @@ def test_from_dict(self, monkeypatch):
"max_tokens": 10,
"some_test_param": "test-params",
},
"timeout": None,
"max_retries": None,
},
}
component = CohereChatGenerator.from_dict(data)
Expand All @@ -265,6 +277,8 @@ def test_from_dict(self, monkeypatch):
"max_tokens": 10,
"some_test_param": "test-params",
}
assert component.timeout is None
assert component.max_retries is None

def test_from_dict_fail_wo_env_var(self, monkeypatch):
monkeypatch.delenv("COHERE_API_KEY", raising=False)
Expand All @@ -284,6 +298,8 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch):
"max_tokens": 10,
"some_test_param": "test-params",
},
"timeout": None,
"max_retries": None,
},
}
with pytest.raises(ValueError):
Expand Down Expand Up @@ -327,6 +343,8 @@ def test_serde_in_pipeline(self, monkeypatch):
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"api_base_url": "https://api.cohere.com",
"generation_kwargs": {"temperature": 0.7},
"timeout": None,
"max_retries": None,
"tools": [
{
"type": "haystack.tools.tool.Tool",
Expand Down
12 changes: 8 additions & 4 deletions integrations/cohere/tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_init_default(self, monkeypatch):
assert component.model == "command-a-03-2025"
assert component.streaming_callback is None
assert component.api_base_url == COHERE_API_URL
assert component.model_parameters == {}
assert component.generation_kwargs == {}

def test_init_with_parameters(self):
callback = lambda x: x # noqa: E731
Expand All @@ -38,7 +38,7 @@ def test_init_with_parameters(self):
assert component.model == "command-light"
assert component.streaming_callback == callback
assert component.api_base_url == "test-base-url"
assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"}
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}

def test_to_dict_default(self, monkeypatch):
monkeypatch.setenv("COHERE_API_KEY", "test-api-key")
Expand All @@ -53,6 +53,8 @@ def test_to_dict_default(self, monkeypatch):
"api_base_url": COHERE_API_URL,
"generation_kwargs": {},
"tools": None,
"timeout": None,
"max_retries": None,
},
}

Expand All @@ -75,8 +77,10 @@ def test_to_dict_with_parameters(self, monkeypatch):
"api_base_url": "test-base-url",
"api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"generation_kwargs": {},
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
"tools": None,
"timeout": None,
"max_retries": None,
},
}

Expand All @@ -100,7 +104,7 @@ def test_from_dict(self, monkeypatch):
assert component.model == "command-a-03-2025"
assert component.streaming_callback == print_streaming_chunk
assert component.api_base_url == "test-base-url"
assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"}
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}

@pytest.mark.skipif(
not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None),
Expand Down