Skip to content

Commit 57074b8

Browse files
committed
cheap model
1 parent f8d84cd commit 57074b8

2 files changed

Lines changed: 10 additions & 10 deletions

File tree

integrations/cohere/tests/test_chat_generator.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ def test_run_image(self):
499499
class TestCohereChatGeneratorInference:
500500
def test_live_run(self):
501501
chat_messages = [ChatMessage.from_user("What's the capital of France")]
502-
component = CohereChatGenerator(generation_kwargs={"temperature": 0.8})
502+
component = CohereChatGenerator(model="command-r7b-12-2024", generation_kwargs={"temperature": 0.8})
503503
results = component.run(chat_messages)
504504
assert len(results["replies"]) == 1
505505
message: ChatMessage = results["replies"][0]
@@ -525,7 +525,7 @@ def __call__(self, chunk: StreamingChunk) -> None:
525525
self.responses += chunk.content if chunk.content else ""
526526

527527
callback = Callback()
528-
component = CohereChatGenerator(streaming_callback=callback)
528+
component = CohereChatGenerator(model="command-r7b-12-2024", streaming_callback=callback)
529529
results = component.run([ChatMessage.from_user("What's the capital of France? answer in a word")])
530530

531531
assert len(results["replies"]) == 1
@@ -559,7 +559,7 @@ def test_tools_use_old_way(self):
559559
},
560560
}
561561
]
562-
client = CohereChatGenerator()
562+
client = CohereChatGenerator(model="command-r7b-12-2024")
563563
response = client.run(
564564
messages=[ChatMessage.from_user("What is the current price of AAPL?")],
565565
generation_kwargs={"tools": tools_schema},
@@ -595,7 +595,7 @@ def test_tools_use_with_tools(self):
595595
function=stock_price,
596596
)
597597
initial_messages = [ChatMessage.from_user("What is the current price of AAPL?")]
598-
client = CohereChatGenerator()
598+
client = CohereChatGenerator(model="command-r7b-12-2024")
599599
response = client.run(
600600
messages=initial_messages,
601601
tools=[stock_price_tool],
@@ -650,7 +650,7 @@ def test_live_run_with_tools_streaming(self):
650650

651651
initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
652652
component = CohereChatGenerator(
653-
# Cohere's model that supports tools
653+
model="command-r7b-12-2024",
654654
tools=[weather_tool],
655655
streaming_callback=print_streaming_chunk,
656656
)
@@ -787,7 +787,7 @@ def test_live_run_with_mixed_tools(self):
787787
initial_messages = [
788788
ChatMessage.from_user("What's the weather like in Paris and what is the population of Berlin?")
789789
]
790-
component = CohereChatGenerator(model="command-r-08-2024", tools=mixed_tools)
790+
component = CohereChatGenerator(model="command-r7b-12-2024", tools=mixed_tools)
791791
results = component.run(messages=initial_messages)
792792

793793
assert len(results["replies"]) > 0, "No replies received"

integrations/cohere/tests/test_chat_generator_async.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def stock_price(ticker: str):
2525
class TestCohereChatGeneratorAsyncInference:
2626
async def test_live_run_async(self):
2727
chat_messages = [ChatMessage.from_user("What's the capital of France")]
28-
component = CohereChatGenerator(generation_kwargs={"temperature": 0.8})
28+
component = CohereChatGenerator(model="command-r7b-12-2024", generation_kwargs={"temperature": 0.8})
2929
results = await component.run_async(chat_messages)
3030
assert len(results["replies"]) == 1
3131
message: ChatMessage = results["replies"][0]
@@ -44,7 +44,7 @@ async def callback(chunk: StreamingChunk):
4444
counter += 1
4545
responses += chunk.content if chunk.content else ""
4646

47-
component = CohereChatGenerator(streaming_callback=callback)
47+
component = CohereChatGenerator(model="command-r7b-12-2024", streaming_callback=callback)
4848
results = await component.run_async([ChatMessage.from_user("What's the capital of France? answer in a word")])
4949

5050
assert len(results["replies"]) == 1
@@ -74,7 +74,7 @@ async def test_tools_use_with_tools_async(self):
7474
function=stock_price,
7575
)
7676
initial_messages = [ChatMessage.from_user("What is the current price of AAPL?")]
77-
client = CohereChatGenerator()
77+
client = CohereChatGenerator(model="command-r7b-12-2024")
7878
response = await client.run_async(
7979
messages=initial_messages,
8080
tools=[stock_price_tool],
@@ -137,7 +137,7 @@ async def print_streaming_chunk_async(chunk: StreamingChunk) -> None:
137137

138138
initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
139139
component = CohereChatGenerator(
140-
# Cohere's model that supports tools
140+
model="command-r7b-12-2024",
141141
tools=[weather_tool],
142142
streaming_callback=print_streaming_chunk_async,
143143
)

0 commit comments

Comments
 (0)