Skip to content

Commit 2028be4

Browse files
authored
chore!: CohereChatGenerator - remove **kwargs init parameter + fix test (#2948)
* fix Cohere test
* simplify
* better client handling
* update haystack pin
* cheap model
1 parent f2c3e1b commit 2028be4

6 files changed

Lines changed: 39 additions & 34 deletions

File tree

integrations/cohere/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ classifiers = [
2222
"Programming Language :: Python :: Implementation :: CPython",
2323
"Programming Language :: Python :: Implementation :: PyPy",
2424
]
25-
dependencies = ["haystack-ai>=2.22.0", "cohere>=5.17.0"]
25+
dependencies = ["haystack-ai>=2.23.0", "cohere>=5.17.0"]
2626

2727
[project.urls]
2828
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/cohere#readme"

integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
flatten_tools_or_toolsets,
2323
serialize_tools_or_toolset,
2424
)
25-
from haystack.utils import Secret, deserialize_secrets_inplace
25+
from haystack.utils import Secret
2626
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
2727
from httpx import AsyncClient as AsyncHTTPXClient
2828
from httpx import AsyncHTTPTransport, HTTPTransport
@@ -508,7 +508,6 @@ def __init__(
508508
*,
509509
timeout: float | None = None,
510510
max_retries: int | None = None,
511-
**kwargs: Any,
512511
):
513512
"""
514513
Initialize the CohereChatGenerator instance.
@@ -537,23 +536,18 @@ def __init__(
537536
:param max_retries:
538537
Maximum number of retries to attempt for failed requests. If not set, it defaults to the default set by
539538
the Cohere client.
540-
:param kwargs:
541-
Additional generation parameters. These are merged into `generation_kwargs` for backward compatibility.
542539
543540
"""
544541
_check_duplicate_tool_names(flatten_tools_or_toolsets(tools))
545542

546543
if not api_base_url:
547544
api_base_url = "https://api.cohere.com"
548-
if generation_kwargs is None:
549-
generation_kwargs = {}
550-
if kwargs:
551-
generation_kwargs = {**generation_kwargs, **kwargs}
545+
552546
self.api_key = api_key
553547
self.model = model
554548
self.streaming_callback = streaming_callback
555549
self.api_base_url = api_base_url
556-
self.generation_kwargs = generation_kwargs
550+
self.generation_kwargs = generation_kwargs or {}
557551
self.tools = tools
558552
self.timeout = timeout
559553
self.max_retries = max_retries
@@ -565,14 +559,15 @@ def __init__(
565559
}
566560
if timeout is not None:
567561
client_kwargs["timeout"] = timeout
562+
563+
sync_kwargs = {**client_kwargs}
564+
async_kwargs = {**client_kwargs}
568565
if max_retries is not None:
569-
sync_httpx_client = HTTPXClient(transport=HTTPTransport(retries=max_retries))
570-
async_httpx_client = AsyncHTTPXClient(transport=AsyncHTTPTransport(retries=max_retries))
571-
self.client = ClientV2(**client_kwargs, httpx_client=sync_httpx_client)
572-
self.async_client = AsyncClientV2(**client_kwargs, httpx_client=async_httpx_client)
573-
else:
574-
self.client = ClientV2(**client_kwargs)
575-
self.async_client = AsyncClientV2(**client_kwargs)
566+
sync_kwargs["httpx_client"] = HTTPXClient(transport=HTTPTransport(retries=max_retries))
567+
async_kwargs["httpx_client"] = AsyncHTTPXClient(transport=AsyncHTTPTransport(retries=max_retries))
568+
569+
self.client = ClientV2(**sync_kwargs)
570+
self.async_client = AsyncClientV2(**async_kwargs)
576571

577572
def _get_telemetry_data(self) -> dict[str, Any]:
578573
"""
@@ -593,7 +588,7 @@ def to_dict(self) -> dict[str, Any]:
593588
model=self.model,
594589
streaming_callback=callback_name,
595590
api_base_url=self.api_base_url,
596-
api_key=self.api_key.to_dict(),
591+
api_key=self.api_key,
597592
generation_kwargs=self.generation_kwargs,
598593
tools=serialize_tools_or_toolset(self.tools),
599594
timeout=self.timeout,
@@ -611,7 +606,6 @@ def from_dict(cls, data: dict[str, Any]) -> "CohereChatGenerator":
611606
Deserialized component.
612607
"""
613608
init_params = data.get("init_parameters", {})
614-
deserialize_secrets_inplace(init_params, ["api_key"])
615609
deserialize_tools_or_toolset_inplace(init_params, key="tools")
616610
serialized_callback_handler = init_params.get("streaming_callback")
617611
if serialized_callback_handler:

integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,21 @@ def __init__(
5151
You can check them in model's documentation.
5252
"""
5353

54+
# from_dict deserialization, where `generation_kwargs` is in **kwargs
55+
if "generation_kwargs" in kwargs:
56+
generation_kwargs = kwargs.pop("generation_kwargs")
57+
else:
58+
# direct construction like `CohereGenerator(max_tokens=10)`
59+
generation_kwargs = kwargs
60+
5461
# Note we have to call super() like this because of the way components are dynamically built with the decorator
55-
super(CohereGenerator, self).__init__(api_key, model, streaming_callback, api_base_url, None, **kwargs) # noqa
62+
super(CohereGenerator, self).__init__( # noqa: UP008
63+
api_key=api_key,
64+
model=model,
65+
streaming_callback=streaming_callback,
66+
api_base_url=api_base_url,
67+
generation_kwargs=generation_kwargs,
68+
)
5669

5770
@component.output_types(replies=list[str], meta=list[dict[str, Any]])
5871
def run( # type: ignore[override] # due to incompatible signature with ChatGenerator

integrations/cohere/tests/test_chat_generator.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ def test_run_image(self):
499499
class TestCohereChatGeneratorInference:
500500
def test_live_run(self):
501501
chat_messages = [ChatMessage.from_user("What's the capital of France")]
502-
component = CohereChatGenerator(generation_kwargs={"temperature": 0.8})
502+
component = CohereChatGenerator(model="command-r7b-12-2024", generation_kwargs={"temperature": 0.8})
503503
results = component.run(chat_messages)
504504
assert len(results["replies"]) == 1
505505
message: ChatMessage = results["replies"][0]
@@ -525,7 +525,7 @@ def __call__(self, chunk: StreamingChunk) -> None:
525525
self.responses += chunk.content if chunk.content else ""
526526

527527
callback = Callback()
528-
component = CohereChatGenerator(streaming_callback=callback, stream=True)
528+
component = CohereChatGenerator(model="command-r7b-12-2024", streaming_callback=callback)
529529
results = component.run([ChatMessage.from_user("What's the capital of France? answer in a word")])
530530

531531
assert len(results["replies"]) == 1
@@ -559,7 +559,7 @@ def test_tools_use_old_way(self):
559559
},
560560
}
561561
]
562-
client = CohereChatGenerator()
562+
client = CohereChatGenerator(model="command-r7b-12-2024")
563563
response = client.run(
564564
messages=[ChatMessage.from_user("What is the current price of AAPL?")],
565565
generation_kwargs={"tools": tools_schema},
@@ -595,7 +595,7 @@ def test_tools_use_with_tools(self):
595595
function=stock_price,
596596
)
597597
initial_messages = [ChatMessage.from_user("What is the current price of AAPL?")]
598-
client = CohereChatGenerator()
598+
client = CohereChatGenerator(model="command-r7b-12-2024")
599599
response = client.run(
600600
messages=initial_messages,
601601
tools=[stock_price_tool],
@@ -650,7 +650,7 @@ def test_live_run_with_tools_streaming(self):
650650

651651
initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
652652
component = CohereChatGenerator(
653-
# Cohere's model that supports tools
653+
model="command-r7b-12-2024",
654654
tools=[weather_tool],
655655
streaming_callback=print_streaming_chunk,
656656
)
@@ -702,7 +702,7 @@ def test_pipeline_with_cohere_chat_generator(self):
702702
)
703703

704704
pipeline = Pipeline()
705-
pipeline.add_component("generator", CohereChatGenerator(model="command-r-08-2024", tools=[weather_tool]))
705+
pipeline.add_component("generator", CohereChatGenerator(model="command-r7b-12-2024", tools=[weather_tool]))
706706
pipeline.add_component("tool_invoker", ToolInvoker(tools=[weather_tool]))
707707

708708
pipeline.connect("generator", "tool_invoker")
@@ -787,7 +787,7 @@ def test_live_run_with_mixed_tools(self):
787787
initial_messages = [
788788
ChatMessage.from_user("What's the weather like in Paris and what is the population of Berlin?")
789789
]
790-
component = CohereChatGenerator(model="command-r-08-2024", tools=mixed_tools)
790+
component = CohereChatGenerator(model="command-r7b-12-2024", tools=mixed_tools)
791791
results = component.run(messages=initial_messages)
792792

793793
assert len(results["replies"]) > 0, "No replies received"

integrations/cohere/tests/test_chat_generator_async.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def stock_price(ticker: str):
2525
class TestCohereChatGeneratorAsyncInference:
2626
async def test_live_run_async(self):
2727
chat_messages = [ChatMessage.from_user("What's the capital of France")]
28-
component = CohereChatGenerator(generation_kwargs={"temperature": 0.8})
28+
component = CohereChatGenerator(model="command-r7b-12-2024", generation_kwargs={"temperature": 0.8})
2929
results = await component.run_async(chat_messages)
3030
assert len(results["replies"]) == 1
3131
message: ChatMessage = results["replies"][0]
@@ -44,7 +44,7 @@ async def callback(chunk: StreamingChunk):
4444
counter += 1
4545
responses += chunk.content if chunk.content else ""
4646

47-
component = CohereChatGenerator(streaming_callback=callback)
47+
component = CohereChatGenerator(model="command-r7b-12-2024", streaming_callback=callback)
4848
results = await component.run_async([ChatMessage.from_user("What's the capital of France? answer in a word")])
4949

5050
assert len(results["replies"]) == 1
@@ -74,7 +74,7 @@ async def test_tools_use_with_tools_async(self):
7474
function=stock_price,
7575
)
7676
initial_messages = [ChatMessage.from_user("What is the current price of AAPL?")]
77-
client = CohereChatGenerator()
77+
client = CohereChatGenerator(model="command-r7b-12-2024")
7878
response = await client.run_async(
7979
messages=initial_messages,
8080
tools=[stock_price_tool],
@@ -137,7 +137,7 @@ async def print_streaming_chunk_async(chunk: StreamingChunk) -> None:
137137

138138
initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
139139
component = CohereChatGenerator(
140-
# Cohere's model that supports tools
140+
model="command-r7b-12-2024",
141141
tools=[weather_tool],
142142
streaming_callback=print_streaming_chunk_async,
143143
)

integrations/cohere/tests/test_generator.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,10 @@ def test_from_dict(self, monkeypatch):
9191
"type": "haystack_integrations.components.generators.cohere.generator.CohereGenerator",
9292
"init_parameters": {
9393
"model": "command-a-03-2025",
94-
"max_tokens": 10,
94+
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
9595
"api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
96-
"some_test_param": "test-params",
9796
"api_base_url": "test-base-url",
9897
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
99-
"tools": None,
10098
},
10199
}
102100
component: CohereGenerator = CohereGenerator.from_dict(data)

0 commit comments

Comments (0)