Skip to content

Commit 649d2f5

Browse files
committed
fix: normalize multiple tool calls in Bedrock Converse API
- Refactored message normalization in convertMessages to properly handle all variants of multi-tool-call scenarios
- Introduced aiMessageAccumulator to consolidate consecutive AI messages with text, reasoning, and multiple tool calls into a single assistant message
- Introduced toolResultAccumulator to consolidate consecutive tool result messages into a single user message
- Fixed ValidationException when users structure message chains differently (6 possible variants now supported)
- Added comprehensive test TestConverseAPIMultipleToolCallsVariants covering all 6 user-facing message chain patterns
- Added unit tests for message accumulators and conversion logic to ensure robustness
- Updated existing test recordings to reflect normalized message structure

This ensures the Bedrock Converse API receives properly formatted messages regardless of how users structure their multi-turn conversations with tool calls.
1 parent 1b1c761 commit 649d2f5

6 files changed

Lines changed: 929 additions & 62 deletions

llms/bedrock/bedrockllm_test.go

Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1765,6 +1765,265 @@ func TestAmazonSingleToolCallWithThinkingConverseAPI(t *testing.T) { //nolint:fu
17651765
}
17661766
}
17671767

1768+
// TestAmazonMultipleToolCallsVariantsConverseAPI tests all possible ways users can structure message chains
1769+
// with multiple tool calls, ensuring all variants are normalized correctly for Bedrock API
1770+
func TestAmazonMultipleToolCallsVariantsConverseAPI(t *testing.T) { //nolint:funlen
1771+
ctx := t.Context()
1772+
1773+
httprr.SkipIfNoCredentialsAndRecordingMissing(t, "AWS_ACCESS_KEY_ID")
1774+
1775+
rr := httprr.OpenForTest(t, http.DefaultTransport)
1776+
defer rr.Close()
1777+
1778+
if !rr.Recording() {
1779+
t.Parallel()
1780+
}
1781+
1782+
client, err := setUpTestWithTransport(rr)
1783+
if err != nil {
1784+
t.Fatal(err)
1785+
}
1786+
llm, err := bedrock.New(bedrock.WithClient(client), bedrock.WithConverseAPI())
1787+
if err != nil {
1788+
t.Fatal(err)
1789+
}
1790+
1791+
// Define two tools that are likely to be called together
1792+
tools := []llms.Tool{
1793+
{
1794+
Type: "function",
1795+
Function: &llms.FunctionDefinition{
1796+
Name: "search_web",
1797+
Description: "Search the web for information",
1798+
Parameters: map[string]any{
1799+
"type": "object",
1800+
"properties": map[string]any{
1801+
"query": map[string]any{
1802+
"type": "string",
1803+
"description": "Search query",
1804+
},
1805+
},
1806+
"required": []string{"query"},
1807+
},
1808+
},
1809+
},
1810+
{
1811+
Type: "function",
1812+
Function: &llms.FunctionDefinition{
1813+
Name: "get_current_date",
1814+
Description: "Get the current date and time",
1815+
Parameters: map[string]any{
1816+
"type": "object",
1817+
"properties": map[string]any{},
1818+
},
1819+
},
1820+
},
1821+
}
1822+
1823+
messages := []llms.MessageContent{
1824+
{
1825+
Role: llms.ChatMessageTypeHuman,
1826+
Parts: []llms.ContentPart{
1827+
llms.TextPart("Search for CVE-2020-10188 exploits and also tell me what's the current date"),
1828+
},
1829+
},
1830+
}
1831+
1832+
// First call - model should invoke both tools
1833+
resp1, err := llm.GenerateContent(ctx, messages,
1834+
llms.WithModel(bedrock.ModelAnthropicClaudeHaiku45),
1835+
llms.WithTools(tools),
1836+
llms.WithMaxTokens(8192),
1837+
llms.WithTemperature(1.0),
1838+
)
1839+
if err != nil {
1840+
t.Fatal(err)
1841+
}
1842+
1843+
if len(resp1.Choices) == 0 {
1844+
t.Fatal("Expected at least one choice in response")
1845+
}
1846+
1847+
choice1 := resp1.Choices[0]
1848+
if len(choice1.ToolCalls) < 2 {
1849+
t.Fatalf("Expected at least 2 tool calls, got %d", len(choice1.ToolCalls))
1850+
}
1851+
1852+
// Prepare common data for all variants
1853+
toolCall1 := choice1.ToolCalls[0]
1854+
toolCall2 := choice1.ToolCalls[1]
1855+
aiContent := choice1.Content
1856+
aiReasoning := choice1.Reasoning
1857+
result1 := `{"results": ["CVE-2020-10188 is a buffer overflow in telnetd"]}`
1858+
result2 := `{"date": "2026-03-15"}`
1859+
1860+
// Helper to test a variant
1861+
testVariant := func(t *testing.T, variantName string, buildMessages func() []llms.MessageContent) {
1862+
t.Run(variantName, func(t *testing.T) {
1863+
messages := buildMessages()
1864+
1865+
resp2, err := llm.GenerateContent(ctx, messages,
1866+
llms.WithModel(bedrock.ModelAnthropicClaudeHaiku45),
1867+
llms.WithTools(tools),
1868+
llms.WithMaxTokens(8192),
1869+
llms.WithTemperature(1.0),
1870+
)
1871+
if err != nil {
1872+
t.Fatalf("Variant %s failed: %v", variantName, err)
1873+
}
1874+
1875+
if len(resp2.Choices) == 0 {
1876+
t.Fatal("Expected at least one choice in second response")
1877+
}
1878+
1879+
if !strings.Contains(resp2.Choices[0].Content, "CVE-2020-10188") {
1880+
t.Errorf("Response should mention CVE-2020-10188, got: %s", resp2.Choices[0].Content)
1881+
}
1882+
})
1883+
}
1884+
1885+
// Variant 1: Content separate + all tool calls together + all tool results together
1886+
testVariant(t, "content_separate_toolcalls_together_results_together", func() []llms.MessageContent {
1887+
msgs := append([]llms.MessageContent{}, messages...)
1888+
msgs = append(msgs,
1889+
llms.MessageContent{
1890+
Role: llms.ChatMessageTypeAI,
1891+
Parts: []llms.ContentPart{llms.TextPartWithReasoning(aiContent, aiReasoning)},
1892+
},
1893+
llms.MessageContent{
1894+
Role: llms.ChatMessageTypeAI,
1895+
Parts: []llms.ContentPart{toolCall1, toolCall2},
1896+
},
1897+
llms.MessageContent{
1898+
Role: llms.ChatMessageTypeTool,
1899+
Parts: []llms.ContentPart{
1900+
llms.ToolCallResponse{ToolCallID: toolCall1.ID, Name: toolCall1.FunctionCall.Name, Content: result1},
1901+
llms.ToolCallResponse{ToolCallID: toolCall2.ID, Name: toolCall2.FunctionCall.Name, Content: result2},
1902+
},
1903+
},
1904+
)
1905+
return msgs
1906+
})
1907+
1908+
// Variant 2: Content separate + tool calls separate + tool results together
1909+
testVariant(t, "content_separate_toolcalls_separate_results_together", func() []llms.MessageContent {
1910+
msgs := append([]llms.MessageContent{}, messages...)
1911+
msgs = append(msgs,
1912+
llms.MessageContent{
1913+
Role: llms.ChatMessageTypeAI,
1914+
Parts: []llms.ContentPart{llms.TextPartWithReasoning(aiContent, aiReasoning)},
1915+
},
1916+
llms.MessageContent{
1917+
Role: llms.ChatMessageTypeAI,
1918+
Parts: []llms.ContentPart{toolCall1},
1919+
},
1920+
llms.MessageContent{
1921+
Role: llms.ChatMessageTypeAI,
1922+
Parts: []llms.ContentPart{toolCall2},
1923+
},
1924+
llms.MessageContent{
1925+
Role: llms.ChatMessageTypeTool,
1926+
Parts: []llms.ContentPart{
1927+
llms.ToolCallResponse{ToolCallID: toolCall1.ID, Name: toolCall1.FunctionCall.Name, Content: result1},
1928+
llms.ToolCallResponse{ToolCallID: toolCall2.ID, Name: toolCall2.FunctionCall.Name, Content: result2},
1929+
},
1930+
},
1931+
)
1932+
return msgs
1933+
})
1934+
1935+
// Variant 3: Content separate + tool calls separate + tool results separate
1936+
testVariant(t, "content_separate_toolcalls_separate_results_separate", func() []llms.MessageContent {
1937+
msgs := append([]llms.MessageContent{}, messages...)
1938+
msgs = append(msgs,
1939+
llms.MessageContent{
1940+
Role: llms.ChatMessageTypeAI,
1941+
Parts: []llms.ContentPart{llms.TextPartWithReasoning(aiContent, aiReasoning)},
1942+
},
1943+
llms.MessageContent{
1944+
Role: llms.ChatMessageTypeAI,
1945+
Parts: []llms.ContentPart{toolCall1},
1946+
},
1947+
llms.MessageContent{
1948+
Role: llms.ChatMessageTypeAI,
1949+
Parts: []llms.ContentPart{toolCall2},
1950+
},
1951+
llms.MessageContent{
1952+
Role: llms.ChatMessageTypeTool,
1953+
Parts: []llms.ContentPart{llms.ToolCallResponse{ToolCallID: toolCall1.ID, Name: toolCall1.FunctionCall.Name, Content: result1}},
1954+
},
1955+
llms.MessageContent{
1956+
Role: llms.ChatMessageTypeTool,
1957+
Parts: []llms.ContentPart{llms.ToolCallResponse{ToolCallID: toolCall2.ID, Name: toolCall2.FunctionCall.Name, Content: result2}},
1958+
},
1959+
)
1960+
return msgs
1961+
})
1962+
1963+
// Variant 4: Content + all tool calls together + tool results together
1964+
testVariant(t, "content_with_toolcalls_together_results_together", func() []llms.MessageContent {
1965+
msgs := append([]llms.MessageContent{}, messages...)
1966+
msgs = append(msgs,
1967+
llms.MessageContent{
1968+
Role: llms.ChatMessageTypeAI,
1969+
Parts: []llms.ContentPart{llms.TextPartWithReasoning(aiContent, aiReasoning), toolCall1, toolCall2},
1970+
},
1971+
llms.MessageContent{
1972+
Role: llms.ChatMessageTypeTool,
1973+
Parts: []llms.ContentPart{
1974+
llms.ToolCallResponse{ToolCallID: toolCall1.ID, Name: toolCall1.FunctionCall.Name, Content: result1},
1975+
llms.ToolCallResponse{ToolCallID: toolCall2.ID, Name: toolCall2.FunctionCall.Name, Content: result2},
1976+
},
1977+
},
1978+
)
1979+
return msgs
1980+
})
1981+
1982+
// Variant 5: Content + first tool call together, second tool call separate + tool results separate
1983+
testVariant(t, "content_with_toolcall1_separate_toolcall2_results_separate", func() []llms.MessageContent {
1984+
msgs := append([]llms.MessageContent{}, messages...)
1985+
msgs = append(msgs,
1986+
llms.MessageContent{
1987+
Role: llms.ChatMessageTypeAI,
1988+
Parts: []llms.ContentPart{llms.TextPartWithReasoning(aiContent, aiReasoning), toolCall1},
1989+
},
1990+
llms.MessageContent{
1991+
Role: llms.ChatMessageTypeAI,
1992+
Parts: []llms.ContentPart{toolCall2},
1993+
},
1994+
llms.MessageContent{
1995+
Role: llms.ChatMessageTypeTool,
1996+
Parts: []llms.ContentPart{llms.ToolCallResponse{ToolCallID: toolCall1.ID, Name: toolCall1.FunctionCall.Name, Content: result1}},
1997+
},
1998+
llms.MessageContent{
1999+
Role: llms.ChatMessageTypeTool,
2000+
Parts: []llms.ContentPart{llms.ToolCallResponse{ToolCallID: toolCall2.ID, Name: toolCall2.FunctionCall.Name, Content: result2}},
2001+
},
2002+
)
2003+
return msgs
2004+
})
2005+
2006+
// Variant 6: Content + all tool calls together + tool results separate
2007+
testVariant(t, "content_with_toolcalls_together_results_separate", func() []llms.MessageContent {
2008+
msgs := append([]llms.MessageContent{}, messages...)
2009+
msgs = append(msgs,
2010+
llms.MessageContent{
2011+
Role: llms.ChatMessageTypeAI,
2012+
Parts: []llms.ContentPart{llms.TextPartWithReasoning(aiContent, aiReasoning), toolCall1, toolCall2},
2013+
},
2014+
llms.MessageContent{
2015+
Role: llms.ChatMessageTypeTool,
2016+
Parts: []llms.ContentPart{llms.ToolCallResponse{ToolCallID: toolCall1.ID, Name: toolCall1.FunctionCall.Name, Content: result1}},
2017+
},
2018+
llms.MessageContent{
2019+
Role: llms.ChatMessageTypeTool,
2020+
Parts: []llms.ContentPart{llms.ToolCallResponse{ToolCallID: toolCall2.ID, Name: toolCall2.FunctionCall.Name, Content: result2}},
2021+
},
2022+
)
2023+
return msgs
2024+
})
2025+
}
2026+
17682027
// TestAmazonSequentialToolCallsWithThinkingConverseAPI tests sequential tool calls with thinking
17692028
func TestAmazonSequentialToolCallsWithThinkingConverseAPI(t *testing.T) { //nolint:funlen
17702029
ctx := t.Context()

0 commit comments

Comments
 (0)