Skip to content

Commit e830011

Browse files
authored
chore: Upgrade default model to command-r-08-2024 (#1691)
* Upgrade default model to command-r-plus * Pydoc update * Update default model to command-r-08-2024
1 parent 5e47f8c commit e830011

2 files changed

Lines changed: 14 additions & 14 deletions

File tree

integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ class CohereChatGenerator:
236236
from haystack.utils import Secret
237237
from haystack_integrations.components.generators.cohere import CohereChatGenerator
238238
239-
client = CohereChatGenerator(model="command-r", api_key=Secret.from_env_var("COHERE_API_KEY"))
239+
client = CohereChatGenerator(model="command-r-08-2024", api_key=Secret.from_env_var("COHERE_API_KEY"))
240240
messages = [ChatMessage.from_user("What's Natural Language Processing?")]
241241
client.run(messages)
242242
@@ -278,7 +278,7 @@ def weather(city: str) -> str:
278278
279279
# Create and set up the pipeline
280280
pipeline = Pipeline()
281-
pipeline.add_component("generator", CohereChatGenerator(model="command-r", tools=[weather_tool]))
281+
pipeline.add_component("generator", CohereChatGenerator(model="command-r-08-2024", tools=[weather_tool]))
282282
pipeline.add_component("tool_invoker", ToolInvoker(tools=[weather_tool]))
283283
pipeline.connect("generator", "tool_invoker")
284284
@@ -296,7 +296,7 @@ def weather(city: str) -> str:
296296
def __init__(
297297
self,
298298
api_key: Secret = Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]),
299-
model: str = "command-r",
299+
model: str = "command-r-08-2024",
300300
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
301301
api_base_url: Optional[str] = None,
302302
generation_kwargs: Optional[Dict[str, Any]] = None,

integrations/cohere/tests/test_cohere_chat_generator.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def test_init_default(self, monkeypatch):
4646

4747
component = CohereChatGenerator()
4848
assert component.api_key == Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"])
49-
assert component.model == "command-r"
49+
assert component.model == "command-r-08-2024"
5050
assert component.streaming_callback is None
5151
assert component.api_base_url == "https://api.cohere.com"
5252
assert not component.generation_kwargs
@@ -78,7 +78,7 @@ def test_to_dict_default(self, monkeypatch):
7878
assert data == {
7979
"type": "haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator",
8080
"init_parameters": {
81-
"model": "command-r",
81+
"model": "command-r-08-2024",
8282
"streaming_callback": None,
8383
"api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"},
8484
"api_base_url": "https://api.cohere.com",
@@ -116,15 +116,15 @@ def test_from_dict(self, monkeypatch):
116116
data = {
117117
"type": "haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator",
118118
"init_parameters": {
119-
"model": "command-r",
119+
"model": "command-r-08-2024",
120120
"api_base_url": "test-base-url",
121121
"api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
122122
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
123123
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
124124
},
125125
}
126126
component = CohereChatGenerator.from_dict(data)
127-
assert component.model == "command-r"
127+
assert component.model == "command-r-08-2024"
128128
assert component.streaming_callback is print_streaming_chunk
129129
assert component.api_base_url == "test-base-url"
130130
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
@@ -135,7 +135,7 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch):
135135
data = {
136136
"type": "haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator",
137137
"init_parameters": {
138-
"model": "command-r",
138+
"model": "command-r-08-2024",
139139
"api_base_url": "test-base-url",
140140
"api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"},
141141
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
@@ -226,7 +226,7 @@ def test_tools_use_old_way(self):
226226
},
227227
}
228228
]
229-
client = CohereChatGenerator(model="command-r")
229+
client = CohereChatGenerator(model="command-r-08-2024")
230230
response = client.run(
231231
messages=[ChatMessage.from_user("What is the current price of AAPL?")],
232232
generation_kwargs={"tools": tools_schema},
@@ -267,7 +267,7 @@ def test_tools_use_with_tools(self):
267267
function=stock_price,
268268
)
269269
initial_messages = [ChatMessage.from_user("What is the current price of AAPL?")]
270-
client = CohereChatGenerator(model="command-r")
270+
client = CohereChatGenerator(model="command-r-08-2024")
271271
response = client.run(
272272
messages=initial_messages,
273273
tools=[stock_price_tool],
@@ -327,7 +327,7 @@ def test_live_run_with_tools_streaming(self):
327327

328328
initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
329329
component = CohereChatGenerator(
330-
model="command-r", # Cohere's model that supports tools
330+
model="command-r-08-2024", # Cohere's model that supports tools
331331
tools=[weather_tool],
332332
streaming_callback=print_streaming_chunk,
333333
)
@@ -384,7 +384,7 @@ def test_pipeline_with_cohere_chat_generator(self):
384384
)
385385

386386
pipeline = Pipeline()
387-
pipeline.add_component("generator", CohereChatGenerator(model="command-r", tools=[weather_tool]))
387+
pipeline.add_component("generator", CohereChatGenerator(model="command-r-08-2024", tools=[weather_tool]))
388388
pipeline.add_component("tool_invoker", ToolInvoker(tools=[weather_tool]))
389389

390390
pipeline.connect("generator", "tool_invoker")
@@ -416,7 +416,7 @@ def test_serde_in_pipeline(self, monkeypatch):
416416

417417
# Create generator with specific configuration
418418
generator = CohereChatGenerator(
419-
model="command-r",
419+
model="command-r-08-2024",
420420
generation_kwargs={"temperature": 0.7},
421421
streaming_callback=print_streaming_chunk,
422422
tools=[tool],
@@ -437,7 +437,7 @@ def test_serde_in_pipeline(self, monkeypatch):
437437
"generator": {
438438
"type": "haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator", # noqa: E501
439439
"init_parameters": {
440-
"model": "command-r",
440+
"model": "command-r-08-2024",
441441
"api_key": {"type": "env_var", "env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True},
442442
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
443443
"api_base_url": "https://api.cohere.com",

0 commit comments

Comments
 (0)