Skip to content

Commit 5dd6c66

Browse files
committed
fix: cache tool call names in streaming to handle providers with incomplete chunks
- Added toolCallNameCache map in combineStreamingChatResponse to store tool call names by ID - Modified updateToolCall to accept nameCache parameter and restore missing names from cache - Fixed streaming callback errors when providers omit function names in subsequent chunks (e.g., GPT-4.1 via OpenRouter) - Added unit tests for name caching logic covering first chunk caching and subsequent chunk restoration - Added integration tests for different streaming formats (GPT-4 style multi-chunk, Gemini style single-chunk, parallel tool calls) This ensures streaming tool calls work correctly with all provider formats, including those that send function names only in the first chunk.
1 parent 649d2f5 commit 5dd6c66

2 files changed

Lines changed: 231 additions & 4 deletions

File tree

llms/openai/internal/openaiclient/chat.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -691,8 +691,9 @@ func combineStreamingChatResponse(
691691
defer streaming.CallWithDone(ctx, payload.StreamingFunc) //nolint:errcheck
692692

693693
var (
694-
response ChatCompletionResponse
695-
splitters []reasoning.ChunkContentSplitter
694+
response ChatCompletionResponse
695+
splitters []reasoning.ChunkContentSplitter
696+
toolCallNameCache = make(map[string]string) // Cache tool call names by ID for streaming
696697
)
697698

698699
for streamResponse := range responseChan {
@@ -748,7 +749,7 @@ func combineStreamingChatResponse(
748749
}
749750

750751
for _, toolCall := range choice.Delta.ToolCalls {
751-
updateToolCall(&responseChoice.Message, toolCall)
752+
updateToolCall(&responseChoice.Message, toolCall, toolCallNameCache)
752753

753754
toolCall := streaming.NewToolCall(toolCall.ID, toolCall.Function.Name, toolCall.Function.Arguments)
754755
if err := streaming.CallWithToolCall(ctx, payload.StreamingFunc, toolCall); err != nil {
@@ -805,7 +806,7 @@ func updateFunctionCall(message *ChatMessage, functionCall *FunctionCall) {
805806
}
806807
}
807808

808-
func updateToolCall(message *ChatMessage, delta *StreamedToolCall) {
809+
func updateToolCall(message *ChatMessage, delta *StreamedToolCall, nameCache map[string]string) {
809810
if delta == nil {
810811
return
811812
}
@@ -837,6 +838,11 @@ func updateToolCall(message *ChatMessage, delta *StreamedToolCall) {
837838
toolCall.Type = delta.Type
838839
toolCall.Function.Name = delta.Function.Name
839840
toolCall.Function.Arguments = delta.Function.Arguments
841+
842+
// Cache the tool call name by ID for subsequent chunks
843+
if nameCache != nil && delta.ID != "" {
844+
nameCache[delta.ID] = delta.Function.Name
845+
}
840846
}
841847

842848
// For next delta chunks, append arguments to the current tool call
@@ -847,6 +853,12 @@ func updateToolCall(message *ChatMessage, delta *StreamedToolCall) {
847853
delta.Function.Name = toolCall.Function.Name
848854
delta.ID = toolCall.ID
849855
delta.Type = toolCall.Type
856+
} else if delta.Function.Name == "" && nameCache != nil {
857+
// If ID is present but name is missing (some providers don't send name in subsequent chunks),
858+
// restore the name from cache
859+
if cachedName, ok := nameCache[delta.ID]; ok {
860+
delta.Function.Name = cachedName
861+
}
850862
}
851863
}
852864

llms/openai/internal/openaiclient/chat_sse_test.go

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ import (
77
"net/http"
88
"testing"
99

10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
1012
"github.com/vxcontrol/langchaingo/llms/streaming"
1113
)
1214

@@ -81,3 +83,216 @@ data: [DONE]`,
8183
})
8284
}
8385
}
86+
87+
func TestParseStreamingChatResponse_ToolCallNameCaching(t *testing.T) {
88+
ctx := context.Background()
89+
t.Parallel()
90+
91+
testCases := []struct {
92+
name string
93+
body string
94+
expectedToolCallName string
95+
expectedArguments string
96+
checkStreaming bool
97+
}{
98+
{
99+
name: "gpt4_style_multiple_chunks_without_name",
100+
body: `data: {"id":"1","object":"chat.completion.chunk","created":1234567890,"model":"test","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}
101+
data: {"id":"1","object":"chat.completion.chunk","created":1234567890,"model":"test","choices":[{"index":0,"delta":{"tool_calls":[{"id":"call_123","index":0,"type":"function","function":{"name":"getCurrentWeather","arguments":""}}]},"finish_reason":null}]}
102+
data: {"id":"1","object":"chat.completion.chunk","created":1234567890,"model":"test","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"type":"function","function":{"arguments":"{"}}]},"finish_reason":null}]}
103+
data: {"id":"1","object":"chat.completion.chunk","created":1234567890,"model":"test","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"type":"function","function":{"arguments":"\"location\":"}}]},"finish_reason":null}]}
104+
data: {"id":"1","object":"chat.completion.chunk","created":1234567890,"model":"test","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"type":"function","function":{"arguments":"\"Boston\"}"}}]},"finish_reason":null}]}
105+
data: {"id":"1","object":"chat.completion.chunk","created":1234567890,"model":"test","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}
106+
data: [DONE]`,
107+
expectedToolCallName: "getCurrentWeather",
108+
expectedArguments: `{"location":"Boston"}`,
109+
checkStreaming: true,
110+
},
111+
{
112+
name: "gemini_style_single_chunk_with_name",
113+
body: `data: {"id":"1","object":"chat.completion.chunk","created":1234567890,"model":"test","choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"tool_getCurrentWeather_123","type":"function","function":{"name":"getCurrentWeather","arguments":"{\"location\":\"Boston\"}"}}]},"finish_reason":"tool_calls"}]}
114+
data: [DONE]`,
115+
expectedToolCallName: "getCurrentWeather",
116+
expectedArguments: `{"location":"Boston"}`,
117+
checkStreaming: false,
118+
},
119+
{
120+
name: "parallel_tool_calls_without_names_in_subsequent_chunks",
121+
body: `data: {"id":"1","object":"chat.completion.chunk","created":1234567890,"model":"test","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}
122+
data: {"id":"1","object":"chat.completion.chunk","created":1234567890,"model":"test","choices":[{"index":0,"delta":{"tool_calls":[{"id":"call_1","index":0,"type":"function","function":{"name":"getCurrentWeather","arguments":""}}]},"finish_reason":null}]}
123+
data: {"id":"1","object":"chat.completion.chunk","created":1234567890,"model":"test","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"type":"function","function":{"arguments":"{"}}]},"finish_reason":null}]}
124+
data: {"id":"1","object":"chat.completion.chunk","created":1234567890,"model":"test","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"type":"function","function":{"arguments":"\"location\":\"Boston\"}"}}]},"finish_reason":null}]}
125+
data: {"id":"1","object":"chat.completion.chunk","created":1234567890,"model":"test","choices":[{"index":0,"delta":{"tool_calls":[{"id":"call_2","index":1,"type":"function","function":{"name":"getCurrentTime","arguments":""}}]},"finish_reason":null}]}
126+
data: {"id":"1","object":"chat.completion.chunk","created":1234567890,"model":"test","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"type":"function","function":{"arguments":"{"}}]},"finish_reason":null}]}
127+
data: {"id":"1","object":"chat.completion.chunk","created":1234567890,"model":"test","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"type":"function","function":{"arguments":"\"location\":\"Boston\"}"}}]},"finish_reason":null}]}
128+
data: {"id":"1","object":"chat.completion.chunk","created":1234567890,"model":"test","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}
129+
data: [DONE]`,
130+
expectedToolCallName: "getCurrentWeather", // Check first tool call
131+
expectedArguments: `{"location":"Boston"}`,
132+
checkStreaming: true,
133+
},
134+
}
135+
136+
for _, tc := range testCases {
137+
t.Run(tc.name, func(t *testing.T) {
138+
t.Parallel()
139+
140+
var streamedToolCalls []streaming.ToolCall
141+
142+
r := &http.Response{
143+
StatusCode: http.StatusOK,
144+
Body: io.NopCloser(bytes.NewBufferString(tc.body)),
145+
}
146+
147+
req := &ChatRequest{
148+
StreamingFunc: func(_ context.Context, chunk streaming.Chunk) error {
149+
if chunk.Type == streaming.ChunkTypeToolCall {
150+
streamedToolCalls = append(streamedToolCalls, chunk.ToolCall)
151+
}
152+
return nil
153+
},
154+
}
155+
156+
resp, err := parseStreamingChatResponse(ctx, r, req)
157+
require.NoError(t, err)
158+
require.NotNil(t, resp)
159+
require.Greater(t, len(resp.Choices), 0)
160+
require.Greater(t, len(resp.Choices[0].Message.ToolCalls), 0)
161+
162+
// Check final accumulated tool call
163+
toolCall := resp.Choices[0].Message.ToolCalls[0]
164+
assert.Equal(t, tc.expectedToolCallName, toolCall.Function.Name)
165+
assert.Equal(t, tc.expectedArguments, toolCall.Function.Arguments)
166+
167+
// Check streaming callbacks if needed
168+
if tc.checkStreaming {
169+
require.Greater(t, len(streamedToolCalls), 0, "should have streamed tool calls")
170+
// All streamed tool calls should have the name filled in
171+
for i, streamedTC := range streamedToolCalls {
172+
assert.NotEmpty(t, streamedTC.Name, "streamed tool call %d should have name", i)
173+
}
174+
}
175+
})
176+
}
177+
}
178+
179+
func TestUpdateToolCall_NameCaching(t *testing.T) {
180+
t.Parallel()
181+
182+
t.Run("cache_name_on_first_chunk", func(t *testing.T) {
183+
t.Parallel()
184+
185+
message := &ChatMessage{}
186+
nameCache := make(map[string]string)
187+
index := 0
188+
189+
delta := &StreamedToolCall{
190+
ID: "call_123",
191+
Type: "function",
192+
Index: &index,
193+
Function: ToolFunction{
194+
Name: "testFunction",
195+
Arguments: `{"arg":"value"}`,
196+
},
197+
}
198+
199+
updateToolCall(message, delta, nameCache)
200+
201+
assert.Equal(t, 1, len(message.ToolCalls))
202+
assert.Equal(t, "testFunction", message.ToolCalls[0].Function.Name)
203+
assert.Equal(t, "testFunction", nameCache["call_123"])
204+
})
205+
206+
t.Run("restore_name_from_cache_on_subsequent_chunks", func(t *testing.T) {
207+
t.Parallel()
208+
209+
message := &ChatMessage{
210+
ToolCalls: []ToolCall{
211+
{
212+
ID: "call_123",
213+
Type: "function",
214+
Function: ToolFunction{
215+
Name: "testFunction",
216+
Arguments: `{"arg":`,
217+
},
218+
},
219+
},
220+
}
221+
nameCache := map[string]string{
222+
"call_123": "testFunction",
223+
}
224+
index := 0
225+
226+
// Subsequent chunk without ID and name (like GPT-4.1 style)
227+
delta := &StreamedToolCall{
228+
ID: "", // No ID in subsequent chunk
229+
Type: "function",
230+
Index: &index,
231+
Function: ToolFunction{
232+
Name: "", // No name in subsequent chunk
233+
Arguments: `"value"}`,
234+
},
235+
}
236+
237+
updateToolCall(message, delta, nameCache)
238+
239+
// Name should be restored from the message's tool call
240+
assert.Equal(t, "testFunction", delta.Function.Name)
241+
assert.Equal(t, "call_123", delta.ID)
242+
assert.Equal(t, "testFunction", message.ToolCalls[0].Function.Name)
243+
assert.Equal(t, `{"arg":"value"}`, message.ToolCalls[0].Function.Arguments)
244+
})
245+
246+
t.Run("handle_multiple_tool_calls_with_cache", func(t *testing.T) {
247+
t.Parallel()
248+
249+
message := &ChatMessage{}
250+
nameCache := make(map[string]string)
251+
252+
// First tool call - first chunk
253+
index0 := 0
254+
delta1 := &StreamedToolCall{
255+
ID: "call_1",
256+
Type: "function",
257+
Index: &index0,
258+
Function: ToolFunction{
259+
Name: "function1",
260+
Arguments: "",
261+
},
262+
}
263+
updateToolCall(message, delta1, nameCache)
264+
265+
// Second tool call - first chunk
266+
index1 := 1
267+
delta2 := &StreamedToolCall{
268+
ID: "call_2",
269+
Type: "function",
270+
Index: &index1,
271+
Function: ToolFunction{
272+
Name: "function2",
273+
Arguments: "",
274+
},
275+
}
276+
updateToolCall(message, delta2, nameCache)
277+
278+
assert.Equal(t, 2, len(nameCache))
279+
assert.Equal(t, "function1", nameCache["call_1"])
280+
assert.Equal(t, "function2", nameCache["call_2"])
281+
282+
// First tool call - subsequent chunk without ID and name
283+
delta3 := &StreamedToolCall{
284+
ID: "", // No ID in subsequent chunk
285+
Type: "function",
286+
Index: &index0,
287+
Function: ToolFunction{
288+
Name: "",
289+
Arguments: `{"a":"b"}`,
290+
},
291+
}
292+
updateToolCall(message, delta3, nameCache)
293+
294+
assert.Equal(t, "function1", delta3.Function.Name)
295+
assert.Equal(t, "call_1", delta3.ID)
296+
assert.Equal(t, `{"a":"b"}`, message.ToolCalls[0].Function.Arguments)
297+
})
298+
}

0 commit comments

Comments (0)