Skip to content

Commit 54f186e

Browse files
committed
fix(litellm): resolve tool compatibility with Cerebras and Groq providers
- Add _format_request_message_contents method for LiteLLM-compatible content formatting
- Override format_request_messages to handle tool messages properly for Cerebras/Groq
- Update structured_output method to use new message formatting
- Fix content format from list to string for text messages (Cerebras/Groq requirement)
- Maintain proper tool call and tool result formatting
- Add comprehensive test coverage for tool message handling

Fixes strands-agents#729 - Now supports agents with tools using Cerebras and Groq providers
1 parent 8caa9cb commit 54f186e

2 files changed

Lines changed: 213 additions & 2 deletions

File tree

src/strands/models/litellm.py

Lines changed: 152 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,157 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]
103103

104104
return super().format_request_message_content(content)
105105

106+
def _format_request_message_contents(self, role: str, content: ContentBlock) -> list[dict[str, Any]]:
    """Turn a single content block into LiteLLM-style chat message(s).

    Text is emitted as a bare string instead of a list of content blocks, which is the
    shape providers such as Cerebras and Groq are expected to accept through LiteLLM.

    Args:
        role: Message role (e.g., "user", "assistant").
        content: Content block to format.

    Returns:
        A one-element list holding the formatted message.

    Raises:
        TypeError: If the content block type cannot be converted to a LiteLLM-compatible format.
    """

    def _render_tool_part(part: Any) -> str:
        # JSON payloads are serialized, text is passed through, anything else is stringified.
        if "json" in part:
            return json.dumps(part["json"])
        if "text" in part:
            return part["text"]
        return str(part)

    if "text" in content:
        return [{"role": role, "content": content["text"]}]

    if "image" in content:
        # Images keep the structured (list-of-blocks) representation.
        structured_block = self.format_request_message_content(content)
        return [{"role": role, "content": [structured_block]}]

    if "toolUse" in content:
        tool_use = content["toolUse"]
        tool_call = {
            "function": {
                "name": tool_use["name"],
                "arguments": json.dumps(tool_use["input"]),
            },
            "id": tool_use["toolUseId"],
            "type": "function",
        }
        return [{"role": role, "tool_calls": [tool_call]}]

    if "toolResult" in content:
        tool_result = content["toolResult"]
        # All result parts are flattened into one space-separated string.
        flattened = " ".join([_render_tool_part(part) for part in tool_result["content"]])
        return [
            {
                "role": "tool",
                "tool_call_id": tool_result["toolUseId"],
                "content": flattened,
            }
        ]

    raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")
167+
168+
@override
@classmethod
def format_request_messages(cls, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
    """Format a LiteLLM compatible messages array.

    Overrides the parent implementation so that text content is sent as a plain string
    rather than a list of content blocks, for compatibility with LiteLLM providers such
    as Cerebras and Groq.

    For each input message the output contains, in order:

    1. At most one message with the original role, carrying the text blocks joined by a
       single space as ``content`` and/or every tool use of that turn in ``tool_calls``.
       ``content`` is omitted when there is no text and ``tool_calls`` is omitted when
       there are no tool uses.
    2. One ``"tool"`` role message per tool result.
    3. One structured (list-content) message per remaining block (e.g. images).

    Args:
        messages: List of message objects to be processed by the model.
        system_prompt: System prompt to provide context to the model.

    Returns:
        A LiteLLM compatible messages array.
    """
    formatted_messages: list[dict[str, Any]] = []

    if system_prompt:
        formatted_messages.append({"role": "system", "content": system_prompt})

    for message in messages:
        role = message["role"]

        text_parts: list[str] = []
        tool_calls: list[dict[str, Any]] = []
        tool_messages: list[dict[str, Any]] = []
        other_messages: list[dict[str, Any]] = []

        # Single pass over the blocks; classification rules match the previous
        # per-type filters (tool blocks are never treated as text or "other").
        for content in message["content"]:
            is_tool_use = "toolUse" in content
            is_tool_result = "toolResult" in content

            if is_tool_use:
                tool_use = content["toolUse"]
                tool_calls.append(
                    {
                        "function": {
                            "name": tool_use["name"],
                            "arguments": json.dumps(tool_use["input"]),
                        },
                        "id": tool_use["toolUseId"],
                        "type": "function",
                    }
                )

            if is_tool_result:
                tool_result = content["toolResult"]
                result_parts = []
                for tool_content in tool_result["content"]:
                    if "json" in tool_content:
                        result_parts.append(json.dumps(tool_content["json"]))
                    elif "text" in tool_content:
                        result_parts.append(tool_content["text"])
                    else:
                        result_parts.append(str(tool_content))
                tool_messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": tool_result["toolUseId"],
                        "content": " ".join(result_parts),
                    }
                )

            if is_tool_use or is_tool_result:
                continue

            if "text" in content:
                text_parts.append(content["text"])
            else:
                # Images and other block types keep the structured format.
                other_messages.append({"role": role, "content": [cls.format_request_message_content(content)]})

        # Keep the text and ALL tool calls of one turn in a single message. Emitting one
        # assistant message per tool call separates calls from their "tool" replies,
        # which OpenAI-style chat APIs reject for parallel tool calls.
        if text_parts or tool_calls:
            turn_message: dict[str, Any] = {"role": role}
            if text_parts:
                turn_message["content"] = " ".join(text_parts)
            if tool_calls:
                turn_message["tool_calls"] = tool_calls
            formatted_messages.append(turn_message)

        formatted_messages.extend(tool_messages)
        formatted_messages.extend(other_messages)

    return formatted_messages
256+
106257
@override
107258
async def stream(
108259
self,
@@ -200,7 +351,7 @@ async def structured_output(
200351
response = await litellm.acompletion(
201352
**self.client_args,
202353
model=self.get_config()["model_id"],
203-
messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
354+
messages=self.format_request_messages(prompt, system_prompt=system_prompt),
204355
response_format=output_model,
205356
)
206357

tests/strands/models/test_litellm.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator,
189189
expected_request = {
190190
"api_key": api_key,
191191
"model": model_id,
192-
"messages": [{"role": "user", "content": [{"text": "calculate 2+2", "type": "text"}]}],
192+
"messages": [{"role": "user", "content": "calculate 2+2"}],
193193
"stream": True,
194194
"stream_options": {"include_usage": True},
195195
"tools": [],
@@ -233,6 +233,66 @@ async def test_stream_empty(litellm_acompletion, api_key, model_id, model, agene
233233
litellm_acompletion.assert_called_once_with(**expected_request)
234234

235235

236+
@pytest.mark.asyncio
async def test_format_request_messages_with_tools():
    """Check the flattened output for a user question, a tool call and its tool result."""
    tool_use_id = "call_123"
    tool_name = "calculator"

    # Conversation: user question -> assistant tool call -> tool result.
    user_turn = {"role": "user", "content": [{"text": "What is 2+2?"}]}
    assistant_turn = {
        "role": "assistant",
        "content": [
            {
                "toolUse": {
                    "toolUseId": tool_use_id,
                    "name": tool_name,
                    "input": {"expression": "2+2"},
                },
            },
        ],
    }
    tool_turn = {
        "role": "tool",
        "content": [
            {
                "toolResult": {
                    "toolUseId": tool_use_id,
                    "content": [{"text": "4"}],
                },
            },
        ],
    }

    formatted = LiteLLMModel.format_request_messages([user_turn, assistant_turn, tool_turn])

    expected_tool_call = {
        "function": {"name": tool_name, "arguments": '{"expression": "2+2"}'},
        "id": tool_use_id,
        "type": "function",
    }
    expected = [
        {"role": "user", "content": "What is 2+2?"},
        {"role": "assistant", "tool_calls": [expected_tool_call]},
        {"role": "tool", "tool_call_id": tool_use_id, "content": "4"},
    ]

    assert formatted == expected
294+
295+
236296
@pytest.mark.asyncio
237297
async def test_structured_output(litellm_acompletion, model, test_output_model_cls, alist):
238298
messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]

0 commit comments

Comments
 (0)