Skip to content

Commit 3bb8d8a

Browse files
xgopilotphantom5099
andcommitted
fix(provider): align openaicompat identity and responses semantics
Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com>
1 parent fda1443 commit 3bb8d8a

7 files changed

Lines changed: 206 additions & 21 deletions

File tree

internal/provider/identity.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,6 @@ func NormalizeProviderIdentity(identity ProviderIdentity) (ProviderIdentity, err
175175
if err != nil {
176176
return ProviderIdentity{}, err
177177
}
178-
if chatEndpointPath == "/chat/completions" {
179-
chatEndpointPath = ""
180-
}
181178
return ProviderIdentity{
182179
Driver: normalizedDriver,
183180
BaseURL: normalizedBaseURL,

internal/provider/identity_test.go

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,43 @@ func TestNormalizeProviderIdentityUsesDriverSpecificNormalization(t *testing.T)
3636
t.Fatalf("expected normalized base url %q, got %q", "https://api.example.com/v1", identity.BaseURL)
3737
}
3838
if identity.ChatEndpointPath != "" {
39-
t.Fatalf("expected default chat/completions path to be omitted from identity, got %q", identity.ChatEndpointPath)
39+
t.Fatalf("expected empty chat endpoint path to stay empty in identity, got %q", identity.ChatEndpointPath)
4040
}
4141
if identity.DiscoveryEndpointPath != "/models" {
4242
t.Fatalf("expected normalized discovery endpoint path %q, got %q", "/models", identity.DiscoveryEndpointPath)
4343
}
4444
}
4545

46+
func TestNormalizeProviderIdentityOpenAICompatPreservesChatEndpointSemanticDifference(t *testing.T) {
47+
t.Parallel()
48+
49+
directIdentity, err := NormalizeProviderIdentity(ProviderIdentity{
50+
Driver: DriverOpenAICompat,
51+
BaseURL: "https://api.example.com/v1",
52+
ChatEndpointPath: "",
53+
DiscoveryEndpointPath: "/models",
54+
})
55+
if err != nil {
56+
t.Fatalf("NormalizeProviderIdentity() direct error = %v", err)
57+
}
58+
chatIdentity, err := NormalizeProviderIdentity(ProviderIdentity{
59+
Driver: DriverOpenAICompat,
60+
BaseURL: "https://api.example.com/v1",
61+
ChatEndpointPath: "/chat/completions",
62+
DiscoveryEndpointPath: "/models",
63+
})
64+
if err != nil {
65+
t.Fatalf("NormalizeProviderIdentity() chat error = %v", err)
66+
}
67+
68+
if directIdentity.Key() == chatIdentity.Key() {
69+
t.Fatalf("expected distinct identity keys for direct mode and /chat/completions, got %q", directIdentity.Key())
70+
}
71+
if chatIdentity.ChatEndpointPath != "/chat/completions" {
72+
t.Fatalf("expected /chat/completions to be preserved, got %q", chatIdentity.ChatEndpointPath)
73+
}
74+
}
75+
4676
func TestNormalizeProviderIdentityShrinksSDKDriverFields(t *testing.T) {
4777
t.Parallel()
4878

internal/provider/openaicompat/chatcompletions/request.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,17 @@ func ToOpenAIMessage(ctx context.Context, message providertypes.Message, assetRe
9494
return msg, err
9595
}
9696

97+
// ToOpenAIMessageWithBudget 将通用 Message 转换为 OpenAI 协议消息格式,并应用会话附件预算限制。
98+
func ToOpenAIMessageWithBudget(
99+
ctx context.Context,
100+
message providertypes.Message,
101+
assetReader providertypes.SessionAssetReader,
102+
remainingAssetBudget int64,
103+
assetLimits providertypes.SessionAssetLimits,
104+
) (Message, int64, error) {
105+
return toOpenAIMessageWithBudget(ctx, message, assetReader, remainingAssetBudget, assetLimits)
106+
}
107+
97108
// toOpenAIMessageWithBudget 将通用 Message 转换为 OpenAI 协议消息格式,并记录 session_asset 消耗字节数。
98109
func toOpenAIMessageWithBudget(
99110
ctx context.Context,

internal/provider/openaicompat/responses/request.go

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,21 @@ func BuildRequest(ctx context.Context, cfg provider.RuntimeConfig, req providert
3333
payload.Instructions = req.SystemPrompt
3434
}
3535

36+
assetLimits := providertypes.NormalizeSessionAssetLimits(cfg.SessionAssetLimits)
37+
var usedSessionAssetBytes int64
3638
for _, message := range req.Messages {
37-
items, err := toResponsesInputItems(ctx, message, req.SessionAssetReader)
39+
remainingSessionAssetBytes := assetLimits.MaxSessionAssetsTotalBytes - usedSessionAssetBytes
40+
items, consumedBytes, err := toResponsesInputItems(
41+
ctx,
42+
message,
43+
req.SessionAssetReader,
44+
remainingSessionAssetBytes,
45+
assetLimits,
46+
)
3847
if err != nil {
3948
return Request{}, err
4049
}
50+
usedSessionAssetBytes += consumedBytes
4151
payload.Input = append(payload.Input, items...)
4252
}
4353

@@ -62,20 +72,28 @@ func toResponsesInputItems(
6272
ctx context.Context,
6373
message providertypes.Message,
6474
assetReader providertypes.SessionAssetReader,
65-
) ([]InputItem, error) {
66-
openaiMessage, err := chatcompletions.ToOpenAIMessage(ctx, message, assetReader)
75+
remainingAssetBudget int64,
76+
assetLimits providertypes.SessionAssetLimits,
77+
) ([]InputItem, int64, error) {
78+
openaiMessage, consumedBytes, err := chatcompletions.ToOpenAIMessageWithBudget(
79+
ctx,
80+
message,
81+
assetReader,
82+
remainingAssetBudget,
83+
assetLimits,
84+
)
6785
if err != nil {
68-
return nil, err
86+
return nil, 0, err
6987
}
7088

7189
switch strings.TrimSpace(openaiMessage.Role) {
7290
case providertypes.RoleSystem:
73-
return nil, nil
91+
return nil, consumedBytes, nil
7492
case providertypes.RoleUser, providertypes.RoleAssistant:
7593
items := make([]InputItem, 0, 1+len(openaiMessage.ToolCalls))
7694
contentParts, err := toResponsesContentParts(openaiMessage.Content)
7795
if err != nil {
78-
return nil, err
96+
return nil, 0, err
7997
}
8098
if len(contentParts) > 0 {
8199
items = append(items, InputItem{
@@ -96,23 +114,23 @@ func toResponsesInputItems(
96114
})
97115
}
98116
}
99-
return items, nil
117+
return items, consumedBytes, nil
100118
case providertypes.RoleTool:
101119
callID := strings.TrimSpace(openaiMessage.ToolCallID)
102120
if callID == "" {
103-
return nil, errors.New(errorPrefix + "tool result message requires tool_call_id")
121+
return nil, 0, errors.New(errorPrefix + "tool result message requires tool_call_id")
104122
}
105123
output, err := renderToolOutput(openaiMessage.Content)
106124
if err != nil {
107-
return nil, err
125+
return nil, 0, err
108126
}
109127
return []InputItem{{
110128
Type: "function_call_output",
111129
CallID: callID,
112130
Output: output,
113-
}}, nil
131+
}}, consumedBytes, nil
114132
default:
115-
return nil, fmt.Errorf("%sunsupported message role %q", errorPrefix, openaiMessage.Role)
133+
return nil, 0, fmt.Errorf("%sunsupported message role %q", errorPrefix, openaiMessage.Role)
116134
}
117135
}
118136

internal/provider/openaicompat/responses/request_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,27 @@ package responses
22

33
import (
44
"context"
5+
"io"
56
"strings"
67
"testing"
78

89
"neo-code/internal/provider"
910
providertypes "neo-code/internal/provider/types"
1011
)
1112

13+
type stubAssetReader struct {
14+
data map[string][]byte
15+
mime map[string]string
16+
}
17+
18+
func (s *stubAssetReader) Open(_ context.Context, assetID string) (io.ReadCloser, string, error) {
19+
content, ok := s.data[assetID]
20+
if !ok {
21+
return nil, "", io.EOF
22+
}
23+
return io.NopCloser(strings.NewReader(string(content))), s.mime[assetID], nil
24+
}
25+
1226
func TestBuildRequestUsesDefaultModelAndMapsMessages(t *testing.T) {
1327
t.Parallel()
1428

@@ -114,6 +128,44 @@ func TestBuildRequestValidationErrors(t *testing.T) {
114128
t.Fatalf("expected tool_call_id error, got %v", err)
115129
}
116130
})
131+
132+
t.Run("session asset total budget respects runtime limits", func(t *testing.T) {
133+
t.Parallel()
134+
135+
assetReader := &stubAssetReader{
136+
data: map[string][]byte{
137+
"asset_1": []byte("PN"),
138+
"asset_2": []byte("PN"),
139+
},
140+
mime: map[string]string{
141+
"asset_1": "image/png",
142+
"asset_2": "image/png",
143+
},
144+
}
145+
146+
_, err := BuildRequest(context.Background(), provider.RuntimeConfig{
147+
DefaultModel: "m",
148+
SessionAssetLimits: providertypes.SessionAssetLimits{
149+
MaxSessionAssetBytes: 2,
150+
MaxSessionAssetsTotalBytes: 3,
151+
},
152+
}, providertypes.GenerateRequest{
153+
Messages: []providertypes.Message{
154+
{
155+
Role: providertypes.RoleUser,
156+
Parts: []providertypes.ContentPart{providertypes.NewSessionAssetImagePart("asset_1", "image/png")},
157+
},
158+
{
159+
Role: providertypes.RoleUser,
160+
Parts: []providertypes.ContentPart{providertypes.NewSessionAssetImagePart("asset_2", "image/png")},
161+
},
162+
},
163+
SessionAssetReader: assetReader,
164+
})
165+
if err == nil || !strings.Contains(err.Error(), "session_asset total exceeds 3 bytes") {
166+
t.Fatalf("expected runtime session asset total budget error, got %v", err)
167+
}
168+
})
117169
}
118170

119171
func TestToResponsesContentPartsAndRenderToolOutput(t *testing.T) {

internal/provider/openaicompat/responses/stream.go

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ func ConsumeStream(
5858
Function: responseFunctionCall{
5959
Arguments: event.Delta,
6060
},
61+
ArgumentsMode: responseToolCallArgumentsMergeAppend,
6162
}
6263
if err := mergeToolCallDelta(ctx, events, toolCalls, delta); err != nil {
6364
return err
@@ -66,6 +67,10 @@ func ConsumeStream(
6667
if event.Item == nil || strings.TrimSpace(event.Item.Type) != "function_call" {
6768
return nil
6869
}
70+
argumentsMode := responseToolCallArgumentsMergeAppend
71+
if strings.TrimSpace(event.Type) == "response.output_item.done" {
72+
argumentsMode = responseToolCallArgumentsMergeReplace
73+
}
6974
toolIndex := resolveToolCallIndex(event.OutputIndex, event.Item.ID, itemToolCallMap, &nextToolCallSlot)
7075
toolCallID := strings.TrimSpace(event.Item.CallID)
7176
if toolCallID == "" {
@@ -78,6 +83,7 @@ func ConsumeStream(
7883
Name: strings.TrimSpace(event.Item.Name),
7984
Arguments: event.Item.Arguments,
8085
},
86+
ArgumentsMode: argumentsMode,
8187
}
8288
if err := mergeToolCallDelta(ctx, events, toolCalls, delta); err != nil {
8389
return err
@@ -194,11 +200,19 @@ type responseFunctionCall struct {
194200
}
195201

196202
type responseToolCallDelta struct {
197-
Index int
198-
ID string
199-
Function responseFunctionCall
203+
Index int
204+
ID string
205+
Function responseFunctionCall
206+
ArgumentsMode responseToolCallArgumentsMergeMode
200207
}
201208

209+
type responseToolCallArgumentsMergeMode int
210+
211+
const (
212+
responseToolCallArgumentsMergeAppend responseToolCallArgumentsMergeMode = iota
213+
responseToolCallArgumentsMergeReplace
214+
)
215+
202216
// mergeToolCallDelta 将单个 tool call 增量合并到累积状态,并在必要时发出统一事件。
203217
func mergeToolCallDelta(
204218
ctx context.Context,
@@ -227,9 +241,29 @@ func mergeToolCallDelta(
227241
}
228242

229243
if args := delta.Function.Arguments; args != "" {
230-
call.Arguments += args
231-
if err := provider.EmitToolCallDelta(ctx, events, delta.Index, call.ID, args); err != nil {
232-
return err
244+
switch delta.ArgumentsMode {
245+
case responseToolCallArgumentsMergeReplace:
246+
if call.Arguments == args {
247+
return nil
248+
}
249+
emitDelta := args
250+
if strings.HasPrefix(args, call.Arguments) {
251+
emitDelta = strings.TrimPrefix(args, call.Arguments)
252+
}
253+
call.Arguments = args
254+
if emitDelta != "" {
255+
if err := provider.EmitToolCallDelta(ctx, events, delta.Index, call.ID, emitDelta); err != nil {
256+
return err
257+
}
258+
}
259+
default:
260+
if strings.HasSuffix(call.Arguments, args) {
261+
return nil
262+
}
263+
call.Arguments += args
264+
if err := provider.EmitToolCallDelta(ctx, events, delta.Index, call.ID, args); err != nil {
265+
return err
266+
}
233267
}
234268
}
235269
return nil

internal/provider/openaicompat/responses/stream_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,49 @@ func TestConsumeStreamCompletedWithToolCalls(t *testing.T) {
6565
}
6666
}
6767

68+
func TestConsumeStreamToolCallArgumentsAddedDeltaDoneNoDuplicate(t *testing.T) {
69+
t.Parallel()
70+
71+
sseBody := strings.Join([]string{
72+
`data: {"type":"response.output_item.added","output_index":0,"item":{"type":"function_call","id":"item_1","call_id":"call_1","name":"read_file","arguments":"{\"path\":\"README.md\"}"}}`,
73+
"",
74+
`data: {"type":"response.function_call_arguments.delta","output_index":0,"item_id":"item_1","delta":"{\"path\":\"README.md\"}"}`,
75+
"",
76+
`data: {"type":"response.output_item.done","output_index":0,"item":{"type":"function_call","id":"item_1","call_id":"call_1","name":"read_file","arguments":"{\"path\":\"README.md\"}"}}`,
77+
"",
78+
`data: {"type":"response.completed","response":{"status":"completed"}}`,
79+
"",
80+
"data: [DONE]",
81+
"",
82+
}, "\n")
83+
84+
events := make(chan providertypes.StreamEvent, 16)
85+
err := ConsumeStream(context.Background(), strings.NewReader(sseBody), events)
86+
if err != nil {
87+
t.Fatalf("ConsumeStream() error = %v", err)
88+
}
89+
90+
collected := drainStreamEvents(events)
91+
if len(collected) != 3 {
92+
t.Fatalf("expected 3 events, got %d", len(collected))
93+
}
94+
95+
start, err := collected[0].ToolCallStartValue()
96+
if err != nil || start.Name != "read_file" {
97+
t.Fatalf("expected tool start event, got err=%v event=%+v", err, collected[0])
98+
}
99+
delta, err := collected[1].ToolCallDeltaValue()
100+
if err != nil {
101+
t.Fatalf("expected tool delta event, got err=%v event=%+v", err, collected[1])
102+
}
103+
if delta.ArgumentsDelta != `{"path":"README.md"}` {
104+
t.Fatalf("expected single non-duplicated arguments delta, got %q", delta.ArgumentsDelta)
105+
}
106+
if _, err := collected[2].MessageDoneValue(); err != nil {
107+
t.Fatalf("expected message done event, got err=%v event=%+v", err, collected[2])
108+
}
109+
}
110+
68111
func TestConsumeStreamIncompleteAndFailures(t *testing.T) {
69112
t.Parallel()
70113

0 commit comments

Comments
 (0)