Skip to content

Commit 5fe2d69

Browse files
xgopilotphantom5099
andcommitted
fix(provider-add): 修复 chat mode 路径联动与模型数值校验一致性
Generated with [codeagent](https://github.com/qbox/codeagent) Co-authored-by: phantom5099 <245659304+phantom5099@users.noreply.github.com>
1 parent f189bbd commit 5fe2d69

7 files changed

Lines changed: 116 additions & 14 deletions

File tree

docs/guides/adding-providers.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ discovery_endpoint_path: /models
103103
说明:
104104
105105
- `chat_api_mode` 仅 `openaicompat` 生效,可选值:`chat_completions` / `responses`。
106-
- `chat_endpoint_path` 为空或 `/` 表示直连 `base_url`,不会自动补子路径
106+
- `chat_endpoint_path` `/` 表示直连 `base_url`;为空时会按 `chat_api_mode` 自动回填默认子路径(`/chat/completions` 或 `/responses`)
107107
- `model_source: manual` 时必须提供 `models`,且会忽略 `discovery_endpoint_path`。
108108

109109
## 测试要求

internal/config/provider_custom_normalize.go

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ import (
88
providertypes "neo-code/internal/provider/types"
99
)
1010

11+
// ManualModelOptionalIntUnset 用于区分“未填写可选数值字段”和“显式输入 0”。
12+
const ManualModelOptionalIntUnset = -1
13+
1114
// NormalizeCustomProviderInput 统一归一化 custom provider 的输入字段,并执行协议/模型来源的组合校验。
1215
func NormalizeCustomProviderInput(input SaveCustomProviderInput) (SaveCustomProviderInput, error) {
1316
normalized := SaveCustomProviderInput{
@@ -40,7 +43,7 @@ func NormalizeCustomProviderInput(input SaveCustomProviderInput) (SaveCustomProv
4043
normalized.ModelSource = ModelSourceDiscover
4144
}
4245

43-
models, err := NormalizeCustomProviderModels(input.Models)
46+
models, err := normalizeCustomProviderModels(input.Models, true)
4447
if err != nil {
4548
return SaveCustomProviderInput{}, err
4649
}
@@ -109,6 +112,14 @@ func NormalizeCustomProviderInput(input SaveCustomProviderInput) (SaveCustomProv
109112

110113
// NormalizeCustomProviderModels 统一归一化 custom provider 的模型描述并校验必填字段和边界条件。
111114
func NormalizeCustomProviderModels(models []providertypes.ModelDescriptor) ([]providertypes.ModelDescriptor, error) {
115+
return normalizeCustomProviderModels(models, false)
116+
}
117+
118+
// normalizeCustomProviderModels 统一归一化 custom provider 模型列表,并在需要时兼容历史的零值省略语义。
119+
func normalizeCustomProviderModels(
120+
models []providertypes.ModelDescriptor,
121+
allowZeroAsUnset bool,
122+
) ([]providertypes.ModelDescriptor, error) {
112123
if len(models) == 0 {
113124
return nil, nil
114125
}
@@ -124,10 +135,16 @@ func NormalizeCustomProviderModels(models []providertypes.ModelDescriptor) ([]pr
124135
if name == "" {
125136
return nil, fmt.Errorf("config: models[%d].name is empty", index)
126137
}
127-
if model.ContextWindow < 0 {
138+
contextWindow := model.ContextWindow
139+
if contextWindow == ManualModelOptionalIntUnset || (allowZeroAsUnset && contextWindow == 0) {
140+
contextWindow = 0
141+
} else if contextWindow <= 0 {
128142
return nil, fmt.Errorf("config: models[%d].context_window must be greater than 0", index)
129143
}
130-
if model.MaxOutputTokens < 0 {
144+
maxOutputTokens := model.MaxOutputTokens
145+
if maxOutputTokens == ManualModelOptionalIntUnset || (allowZeroAsUnset && maxOutputTokens == 0) {
146+
maxOutputTokens = 0
147+
} else if maxOutputTokens <= 0 {
131148
return nil, fmt.Errorf("config: models[%d].max_output_tokens must be greater than 0", index)
132149
}
133150

@@ -141,8 +158,8 @@ func NormalizeCustomProviderModels(models []providertypes.ModelDescriptor) ([]pr
141158
ID: id,
142159
Name: name,
143160
Description: strings.TrimSpace(model.Description),
144-
ContextWindow: model.ContextWindow,
145-
MaxOutputTokens: model.MaxOutputTokens,
161+
ContextWindow: contextWindow,
162+
MaxOutputTokens: maxOutputTokens,
146163
CapabilityHints: model.CapabilityHints,
147164
})
148165
}

internal/config/provider_loader.go

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"sort"
1010
"strings"
1111

12+
"neo-code/internal/provider"
1213
providertypes "neo-code/internal/provider/types"
1314

1415
"gopkg.in/yaml.v3"
@@ -146,6 +147,7 @@ func customProviderModels(models []customProviderModelFile) ([]providertypes.Mod
146147
}
147148

148149
descriptors := make([]providertypes.ModelDescriptor, 0, len(models))
150+
seen := make(map[string]struct{}, len(models))
149151
for index, model := range models {
150152
id := strings.TrimSpace(model.ID)
151153
if id == "" {
@@ -157,9 +159,16 @@ func customProviderModels(models []customProviderModelFile) ([]providertypes.Mod
157159
}
158160

159161
descriptor := providertypes.ModelDescriptor{
160-
ID: id,
161-
Name: name,
162+
ID: id,
163+
Name: name,
164+
ContextWindow: ManualModelOptionalIntUnset,
165+
MaxOutputTokens: ManualModelOptionalIntUnset,
162166
}
167+
key := provider.NormalizeKey(id)
168+
if _, exists := seen[key]; exists {
169+
return nil, fmt.Errorf("models[%d].id %q is duplicated", index, id)
170+
}
171+
seen[key] = struct{}{}
163172
if model.ContextWindow != nil {
164173
if *model.ContextWindow <= 0 {
165174
return nil, fmt.Errorf("models[%d].context_window must be greater than 0", index)
@@ -174,7 +183,7 @@ func customProviderModels(models []customProviderModelFile) ([]providertypes.Mod
174183
}
175184
descriptors = append(descriptors, descriptor)
176185
}
177-
return NormalizeCustomProviderModels(descriptors)
186+
return descriptors, nil
178187
}
179188

180189
// SaveCustomProviderInput 定义自定义 Provider 的持久化字段。

internal/config/provider_test.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,33 @@ func TestCustomProviderModelsRejectsNonPositiveMaxOutputTokens(t *testing.T) {
440440
}
441441
}
442442

443+
func TestNormalizeCustomProviderModelsRejectsZeroLimits(t *testing.T) {
444+
t.Parallel()
445+
446+
_, err := NormalizeCustomProviderModels([]providertypes.ModelDescriptor{
447+
{
448+
ID: "deepseek-coder",
449+
Name: "DeepSeek Coder",
450+
ContextWindow: 0,
451+
},
452+
})
453+
if err == nil || !strings.Contains(err.Error(), "context_window") {
454+
t.Fatalf("expected context_window validation error, got %v", err)
455+
}
456+
457+
_, err = NormalizeCustomProviderModels([]providertypes.ModelDescriptor{
458+
{
459+
ID: "deepseek-coder",
460+
Name: "DeepSeek Coder",
461+
ContextWindow: ManualModelOptionalIntUnset,
462+
MaxOutputTokens: 0,
463+
},
464+
})
465+
if err == nil || !strings.Contains(err.Error(), "max_output_tokens") {
466+
t.Fatalf("expected max_output_tokens validation error, got %v", err)
467+
}
468+
}
469+
443470
func TestCustomProviderModelsRejectsDuplicateID(t *testing.T) {
444471
t.Parallel()
445472

internal/config/state/service_provider_create.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,8 +291,10 @@ func parseManualModelsJSON(raw string) ([]providertypes.ModelDescriptor, error)
291291
seen[key] = struct{}{}
292292

293293
descriptor := providertypes.ModelDescriptor{
294-
ID: id,
295-
Name: name,
294+
ID: id,
295+
Name: name,
296+
ContextWindow: config.ManualModelOptionalIntUnset,
297+
MaxOutputTokens: config.ManualModelOptionalIntUnset,
296298
}
297299
if model.ContextWindow != nil {
298300
if *model.ContextWindow <= 0 {

internal/tui/core/app/update.go

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3066,6 +3066,7 @@ func (a *App) handleProviderAddFormInput(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
30663066
a.providerAddForm.ModelSource = a.providerAddForm.ModelSources[currentIdx]
30673067
clampProviderAddStep(a.providerAddForm)
30683068
} else if currentProviderAddField(a.providerAddForm) == providerAddFieldChatAPIMode {
3069+
previousMode := a.providerAddForm.ChatAPIMode
30693070
currentIdx := 0
30703071
for i, mode := range a.providerAddForm.ChatAPIModes {
30713072
if mode == a.providerAddForm.ChatAPIMode {
@@ -3075,6 +3076,7 @@ func (a *App) handleProviderAddFormInput(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
30753076
}
30763077
currentIdx = (currentIdx - 1 + len(a.providerAddForm.ChatAPIModes)) % len(a.providerAddForm.ChatAPIModes)
30773078
a.providerAddForm.ChatAPIMode = a.providerAddForm.ChatAPIModes[currentIdx]
3079+
syncProviderAddOpenAICompatModeDefaults(a.providerAddForm, previousMode)
30783080
clampProviderAddStep(a.providerAddForm)
30793081
}
30803082
return a, nil
@@ -3106,6 +3108,7 @@ func (a *App) handleProviderAddFormInput(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
31063108
a.providerAddForm.ModelSource = a.providerAddForm.ModelSources[currentIdx]
31073109
clampProviderAddStep(a.providerAddForm)
31083110
} else if currentProviderAddField(a.providerAddForm) == providerAddFieldChatAPIMode {
3111+
previousMode := a.providerAddForm.ChatAPIMode
31093112
currentIdx := 0
31103113
for i, mode := range a.providerAddForm.ChatAPIModes {
31113114
if mode == a.providerAddForm.ChatAPIMode {
@@ -3115,6 +3118,7 @@ func (a *App) handleProviderAddFormInput(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
31153118
}
31163119
currentIdx = (currentIdx + 1) % len(a.providerAddForm.ChatAPIModes)
31173120
a.providerAddForm.ChatAPIMode = a.providerAddForm.ChatAPIModes[currentIdx]
3121+
syncProviderAddOpenAICompatModeDefaults(a.providerAddForm, previousMode)
31183122
clampProviderAddStep(a.providerAddForm)
31193123
}
31203124
return a, nil
@@ -3231,6 +3235,20 @@ func providerAddDefaultOpenAICompatChatEndpointPath(chatAPIMode string) string {
32313235
return "/chat/completions"
32323236
}
32333237

3238+
// syncProviderAddOpenAICompatModeDefaults 在切换 chat_api_mode 时同步默认 chat endpoint,避免默认值错配。
3239+
func syncProviderAddOpenAICompatModeDefaults(form *providerAddFormState, previousMode string) {
3240+
if form == nil || provider.NormalizeProviderDriver(form.Driver) != provider.DriverOpenAICompat {
3241+
return
3242+
}
3243+
3244+
currentPath := strings.TrimSpace(form.ChatEndpointPath)
3245+
previousDefaultPath := providerAddDefaultOpenAICompatChatEndpointPath(previousMode)
3246+
if currentPath != "" && currentPath != previousDefaultPath {
3247+
return
3248+
}
3249+
form.ChatEndpointPath = providerAddDefaultOpenAICompatChatEndpointPath(form.ChatAPIMode)
3250+
}
3251+
32343252
// providerAddDefaultBaseURL 返回 provider add 表单的驱动默认 base URL。
32353253
func providerAddDefaultBaseURL(driver string) string {
32363254
switch provider.NormalizeProviderDriver(driver) {
@@ -3419,11 +3437,19 @@ func parseProviderAddManualModelsJSON(raw string) ([]providertypes.ModelDescript
34193437
}
34203438

34213439
descriptors := make([]providertypes.ModelDescriptor, 0, len(models))
3440+
seen := make(map[string]struct{}, len(models))
34223441
for _, model := range models {
34233442
descriptor := providertypes.ModelDescriptor{
3424-
ID: strings.TrimSpace(model.ID),
3425-
Name: strings.TrimSpace(model.Name),
3443+
ID: strings.TrimSpace(model.ID),
3444+
Name: strings.TrimSpace(model.Name),
3445+
ContextWindow: config.ManualModelOptionalIntUnset,
3446+
MaxOutputTokens: config.ManualModelOptionalIntUnset,
3447+
}
3448+
key := provider.NormalizeKey(descriptor.ID)
3449+
if _, exists := seen[key]; exists {
3450+
return nil, fmt.Errorf("parse manual model json: models.id %q is duplicated", descriptor.ID)
34263451
}
3452+
seen[key] = struct{}{}
34273453
if model.ContextWindow != nil {
34283454
if *model.ContextWindow <= 0 {
34293455
return nil, fmt.Errorf("parse manual model json: models.context_window must be greater than 0")
@@ -3438,7 +3464,7 @@ func parseProviderAddManualModelsJSON(raw string) ([]providertypes.ModelDescript
34383464
}
34393465
descriptors = append(descriptors, descriptor)
34403466
}
3441-
return config.NormalizeCustomProviderModels(descriptors)
3467+
return descriptors, nil
34423468
}
34433469

34443470
// sanitizeProviderAddInputRunes 过滤 provider 表单输入中的控制字符,避免不可见字符污染配置字段。

internal/tui/core/app/update_test.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3471,6 +3471,27 @@ func TestSlashSelectionAndProviderAddUtilityBranches(t *testing.T) {
34713471
app.handleProviderAddResultMsg(providerAddResultMsg{Name: "unused"})
34723472
}
34733473

3474+
func TestSyncProviderAddOpenAICompatModeDefaults(t *testing.T) {
3475+
t.Parallel()
3476+
3477+
form := &providerAddFormState{
3478+
Driver: provider.DriverOpenAICompat,
3479+
ChatAPIMode: provider.ChatAPIModeResponses,
3480+
ChatEndpointPath: "/chat/completions",
3481+
}
3482+
syncProviderAddOpenAICompatModeDefaults(form, provider.ChatAPIModeChatCompletions)
3483+
if form.ChatEndpointPath != "/responses" {
3484+
t.Fatalf("expected default endpoint to follow responses mode, got %q", form.ChatEndpointPath)
3485+
}
3486+
3487+
form.ChatAPIMode = provider.ChatAPIModeChatCompletions
3488+
form.ChatEndpointPath = "/custom/chat"
3489+
syncProviderAddOpenAICompatModeDefaults(form, provider.ChatAPIModeResponses)
3490+
if form.ChatEndpointPath != "/custom/chat" {
3491+
t.Fatalf("expected custom endpoint unchanged, got %q", form.ChatEndpointPath)
3492+
}
3493+
}
3494+
34743495
func TestRunProviderAddFlowDeadlineExceededBranch(t *testing.T) {
34753496
service := stubProviderService{
34763497
createErr: context.DeadlineExceeded,

0 commit comments

Comments
 (0)