Skip to content

Commit fda1443

Browse files
xgopilotphantom5099
andcommitted
test(provider): 增强 openaicompat 子包覆盖率
Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com>
1 parent 8dc17f2 commit fda1443

4 files changed

Lines changed: 600 additions & 0 deletions

File tree

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
package chatcompletions
2+
3+
import (
4+
"context"
5+
"io"
6+
"strings"
7+
"testing"
8+
9+
"neo-code/internal/provider"
10+
providertypes "neo-code/internal/provider/types"
11+
)
12+
13+
type stubAssetReader struct {
14+
data map[string][]byte
15+
mime map[string]string
16+
}
17+
18+
func (s *stubAssetReader) Open(_ context.Context, assetID string) (io.ReadCloser, string, error) {
19+
content, ok := s.data[assetID]
20+
if !ok {
21+
return nil, "", io.EOF
22+
}
23+
return io.NopCloser(strings.NewReader(string(content))), s.mime[assetID], nil
24+
}
25+
26+
func TestBuildRequestUsesDefaultModelAndNormalizesTools(t *testing.T) {
27+
t.Parallel()
28+
29+
cfg := provider.RuntimeConfig{DefaultModel: "gpt-default"}
30+
payload, err := BuildRequest(context.Background(), cfg, providertypes.GenerateRequest{
31+
SystemPrompt: "system",
32+
Messages: []providertypes.Message{
33+
{
34+
Role: providertypes.RoleUser,
35+
Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello"), providertypes.NewTextPart(" world")},
36+
},
37+
},
38+
Tools: []providertypes.ToolSpec{{
39+
Name: "run",
40+
Schema: map[string]any{"type": "array"},
41+
}},
42+
})
43+
if err != nil {
44+
t.Fatalf("BuildRequest() error = %v", err)
45+
}
46+
if payload.Model != "gpt-default" {
47+
t.Fatalf("expected default model, got %q", payload.Model)
48+
}
49+
if !payload.Stream {
50+
t.Fatal("expected stream=true")
51+
}
52+
if len(payload.Messages) != 2 {
53+
t.Fatalf("expected system+user messages, got %d", len(payload.Messages))
54+
}
55+
if payload.Messages[1].Content != "hello world" {
56+
t.Fatalf("expected collapsed user content, got %+v", payload.Messages[1].Content)
57+
}
58+
if len(payload.Tools) != 1 || payload.ToolChoice != "auto" {
59+
t.Fatalf("expected one auto tool, got choice=%q tools=%d", payload.ToolChoice, len(payload.Tools))
60+
}
61+
if gotType, _ := payload.Tools[0].Function.Parameters["type"].(string); gotType != "object" {
62+
t.Fatalf("expected normalized object schema, got %q", gotType)
63+
}
64+
}
65+
66+
func TestBuildRequestAndToOpenAIMessageErrors(t *testing.T) {
67+
t.Parallel()
68+
69+
t.Run("missing model", func(t *testing.T) {
70+
t.Parallel()
71+
72+
_, err := BuildRequest(context.Background(), provider.RuntimeConfig{}, providertypes.GenerateRequest{})
73+
if err == nil || !strings.Contains(err.Error(), "model is empty") {
74+
t.Fatalf("expected model error, got %v", err)
75+
}
76+
})
77+
78+
t.Run("session asset missing reader", func(t *testing.T) {
79+
t.Parallel()
80+
81+
_, err := ToOpenAIMessage(context.Background(), providertypes.Message{
82+
Role: providertypes.RoleUser,
83+
Parts: []providertypes.ContentPart{
84+
providertypes.NewSessionAssetImagePart("asset_1", "image/png"),
85+
},
86+
}, nil)
87+
if err == nil || !strings.Contains(err.Error(), "session_asset reader is not configured") {
88+
t.Fatalf("expected missing reader error, got %v", err)
89+
}
90+
})
91+
92+
t.Run("unsupported image source", func(t *testing.T) {
93+
t.Parallel()
94+
95+
_, _, err := toOpenAIMessageWithBudget(context.Background(), providertypes.Message{
96+
Role: providertypes.RoleUser,
97+
Parts: []providertypes.ContentPart{{
98+
Kind: providertypes.ContentPartImage,
99+
Image: &providertypes.ImagePart{
100+
SourceType: "unsupported",
101+
},
102+
}},
103+
}, nil, 1024, providertypes.DefaultSessionAssetLimits())
104+
if err == nil || !strings.Contains(err.Error(), "unsupported source type") {
105+
t.Fatalf("expected unsupported source type error, got %v", err)
106+
}
107+
})
108+
}
109+
110+
func TestToOpenAIMessageMapsToolCallsAndSessionAsset(t *testing.T) {
111+
t.Parallel()
112+
113+
reader := &stubAssetReader{
114+
data: map[string][]byte{"asset_1": []byte("PNG")},
115+
mime: map[string]string{"asset_1": "image/png"},
116+
}
117+
msg, used, err := toOpenAIMessageWithBudget(context.Background(), providertypes.Message{
118+
Role: providertypes.RoleAssistant,
119+
Parts: []providertypes.ContentPart{
120+
providertypes.NewTextPart("look"),
121+
providertypes.NewSessionAssetImagePart("asset_1", "image/png"),
122+
},
123+
ToolCalls: []providertypes.ToolCall{{
124+
ID: "call_1",
125+
Name: "read_file",
126+
Arguments: "{\"path\":\"README.md\"}",
127+
}},
128+
}, reader, 1024, providertypes.DefaultSessionAssetLimits())
129+
if err != nil {
130+
t.Fatalf("toOpenAIMessageWithBudget() error = %v", err)
131+
}
132+
if used <= 0 {
133+
t.Fatalf("expected consumed session asset bytes, got %d", used)
134+
}
135+
parts, ok := msg.Content.([]MessageContentPart)
136+
if !ok || len(parts) != 2 {
137+
t.Fatalf("expected 2 multimodal parts, got %+v", msg.Content)
138+
}
139+
if parts[1].ImageURL == nil || !strings.HasPrefix(parts[1].ImageURL.URL, "data:image/png;base64,") {
140+
t.Fatalf("expected encoded data url, got %+v", parts[1].ImageURL)
141+
}
142+
if len(msg.ToolCalls) != 1 || msg.ToolCalls[0].Function.Name != "read_file" {
143+
t.Fatalf("expected mapped tool call, got %+v", msg.ToolCalls)
144+
}
145+
}
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
package chatcompletions
2+
3+
import (
4+
"context"
5+
"errors"
6+
"strings"
7+
"testing"
8+
9+
"neo-code/internal/provider"
10+
providertypes "neo-code/internal/provider/types"
11+
)
12+
13+
func TestConsumeStreamEmitsTextToolAndDone(t *testing.T) {
14+
t.Parallel()
15+
16+
sseBody := strings.Join([]string{
17+
`data: {"choices":[{"index":0,"delta":{"content":"Hi "}}]}`,
18+
"",
19+
`data: {"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_1","function":{"name":"read_file","arguments":"{\"path\":"}}]}}]}`,
20+
"",
21+
`data: {"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"README.md\"}"}}]},"finish_reason":"stop"}],"usage":{"prompt_tokens":12,"completion_tokens":8,"total_tokens":20}}`,
22+
"",
23+
"data: [DONE]",
24+
"",
25+
}, "\n")
26+
27+
events := make(chan providertypes.StreamEvent, 16)
28+
err := ConsumeStream(context.Background(), strings.NewReader(sseBody), events)
29+
if err != nil {
30+
t.Fatalf("ConsumeStream() error = %v", err)
31+
}
32+
33+
collected := drainChatEvents(events)
34+
if len(collected) != 5 {
35+
t.Fatalf("expected 5 events, got %d", len(collected))
36+
}
37+
text, err := collected[0].TextDeltaValue()
38+
if err != nil || text.Text != "Hi " {
39+
t.Fatalf("expected text delta event, got err=%v event=%+v", err, collected[0])
40+
}
41+
start, err := collected[1].ToolCallStartValue()
42+
if err != nil || start.Index != 0 || start.ID != "call_1" || start.Name != "read_file" {
43+
t.Fatalf("expected tool start event, got err=%v event=%+v", err, collected[1])
44+
}
45+
delta1, err := collected[2].ToolCallDeltaValue()
46+
if err != nil || delta1.ArgumentsDelta != "{\"path\":" {
47+
t.Fatalf("expected first tool delta, got err=%v event=%+v", err, collected[2])
48+
}
49+
delta2, err := collected[3].ToolCallDeltaValue()
50+
if err != nil || delta2.ArgumentsDelta != "\"README.md\"}" {
51+
t.Fatalf("expected second tool delta, got err=%v event=%+v", err, collected[3])
52+
}
53+
done, err := collected[4].MessageDoneValue()
54+
if err != nil {
55+
t.Fatalf("expected message done event, got err=%v", err)
56+
}
57+
if done.FinishReason != "stop" {
58+
t.Fatalf("expected stop finish reason, got %q", done.FinishReason)
59+
}
60+
if done.Usage == nil || done.Usage.TotalTokens != 20 {
61+
t.Fatalf("expected usage tokens in done event, got %+v", done.Usage)
62+
}
63+
}
64+
65+
func TestConsumeStreamErrorAndEOFBranches(t *testing.T) {
66+
t.Parallel()
67+
68+
t.Run("error payload", func(t *testing.T) {
69+
t.Parallel()
70+
71+
err := ConsumeStream(context.Background(), strings.NewReader("data: {\"error\":{\"message\":\"bad key\"}}\n\n"), make(chan providertypes.StreamEvent, 2))
72+
if err == nil || !strings.Contains(err.Error(), "bad key") {
73+
t.Fatalf("expected error payload propagation, got %v", err)
74+
}
75+
})
76+
77+
t.Run("EOF with finish_reason still emits done", func(t *testing.T) {
78+
t.Parallel()
79+
80+
sseBody := "data: {\"choices\":[{\"delta\":{},\"finish_reason\":\"stop\"}]}\n\n"
81+
events := make(chan providertypes.StreamEvent, 2)
82+
err := ConsumeStream(context.Background(), strings.NewReader(sseBody), events)
83+
if err != nil {
84+
t.Fatalf("expected graceful finish on EOF with finish_reason, got %v", err)
85+
}
86+
collected := drainChatEvents(events)
87+
if len(collected) != 1 {
88+
t.Fatalf("expected one done event, got %d", len(collected))
89+
}
90+
done, err := collected[0].MessageDoneValue()
91+
if err != nil || done.FinishReason != "stop" {
92+
t.Fatalf("expected stop done event, got err=%v event=%+v", err, collected[0])
93+
}
94+
})
95+
96+
t.Run("EOF without done marker and finish reason", func(t *testing.T) {
97+
t.Parallel()
98+
99+
err := ConsumeStream(context.Background(), strings.NewReader("data: {\"choices\":[{\"delta\":{\"content\":\"x\"}}]}\n\n"), make(chan providertypes.StreamEvent, 2))
100+
if err == nil {
101+
t.Fatal("expected interrupted error")
102+
}
103+
if !errors.Is(err, provider.ErrStreamInterrupted) {
104+
t.Fatalf("expected ErrStreamInterrupted, got %v", err)
105+
}
106+
})
107+
}
108+
109+
func TestExtractAndMergeHelpers(t *testing.T) {
110+
t.Parallel()
111+
112+
usage := providertypes.Usage{InputTokens: 1}
113+
extractStreamUsage(&usage, &StreamUsage{PromptTokens: 3, CompletionTokens: 4, TotalTokens: 7})
114+
if usage.TotalTokens != 7 || usage.InputTokens != 3 || usage.OutputTokens != 4 {
115+
t.Fatalf("unexpected usage mapping: %+v", usage)
116+
}
117+
118+
events := make(chan providertypes.StreamEvent, 4)
119+
toolCalls := map[int]*providertypes.ToolCall{}
120+
if err := mergeToolCallDelta(context.Background(), events, toolCalls, StreamToolCallDelta{
121+
Index: 0,
122+
ID: "call_1",
123+
Function: FunctionCall{
124+
Name: "edit",
125+
Arguments: "{",
126+
},
127+
}); err != nil {
128+
t.Fatalf("first mergeToolCallDelta() error = %v", err)
129+
}
130+
if err := mergeToolCallDelta(context.Background(), events, toolCalls, StreamToolCallDelta{
131+
Index: 0,
132+
Function: FunctionCall{
133+
Arguments: "}",
134+
},
135+
}); err != nil {
136+
t.Fatalf("second mergeToolCallDelta() error = %v", err)
137+
}
138+
if toolCalls[0].Arguments != "{}" {
139+
t.Fatalf("expected accumulated arguments, got %q", toolCalls[0].Arguments)
140+
}
141+
142+
collected := drainChatEvents(events)
143+
if len(collected) != 3 {
144+
t.Fatalf("expected start+2deltas events, got %d", len(collected))
145+
}
146+
if _, err := collected[0].ToolCallStartValue(); err != nil {
147+
t.Fatalf("expected first event to be tool start, got err=%v", err)
148+
}
149+
}
150+
151+
func drainChatEvents(events <-chan providertypes.StreamEvent) []providertypes.StreamEvent {
152+
out := make([]providertypes.StreamEvent, 0, len(events))
153+
for {
154+
select {
155+
case ev := <-events:
156+
out = append(out, ev)
157+
default:
158+
return out
159+
}
160+
}
161+
}

0 commit comments

Comments
 (0)