Skip to content

Commit 4d49d5e

Browse files
authored
Merge pull request #337 from phantom5099/main
feat(TUI):收敛多模态输入链路:runtime 单入口提交 + session 归一化落地,清理 TUI 过渡职责并修复 schema/目录行为
2 parents 6a6e7eb + a7e61a1 commit 4d49d5e

29 files changed

Lines changed: 2566 additions & 331 deletions

internal/app/bootstrap.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ func BuildRuntime(ctx context.Context, opts BootstrapOptions) (RuntimeBundle, er
172172
contextBuilder,
173173
)
174174
runtimeSvc.SetSessionAssetStore(sessionStore)
175+
runtimeSvc.SetUserInputPreparer(agentruntime.NewSessionInputPreparer(sessionStore, sessionStore))
175176
runtimeSvc.SetSkillsRegistry(buildSkillsRegistry(ctx, loader.BaseDir()))
176177
runtimeSvc.SetAutoCompactThresholdResolver(runtimeAutoCompactThresholdResolverFunc(
177178
func(ctx context.Context, cfg config.Config) (int, error) {

internal/config/loader_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1059,12 +1059,12 @@ func TestLoadCustomProvidersReadDirAndStatErrors(t *testing.T) {
10591059
}
10601060
defer func() { _ = os.Chmod(providersPath, 0o755) }()
10611061

1062-
_, err := loadCustomProviders(baseDir)
1063-
if err == nil {
1064-
t.Fatal("expected read providers dir error")
1062+
providers, err := loadCustomProviders(baseDir)
1063+
if err != nil {
1064+
t.Fatalf("expected read providers dir fallback, got %v", err)
10651065
}
1066-
if !strings.Contains(err.Error(), "read providers dir") {
1067-
t.Fatalf("expected read providers dir error, got %v", err)
1066+
if len(providers) != 0 {
1067+
t.Fatalf("expected empty providers on read fallback, got %d", len(providers))
10681068
}
10691069
})
10701070

internal/config/provider_loader.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ func loadCustomProviders(baseDir string) ([]ProviderConfig, error) {
6767
if os.IsNotExist(err) {
6868
return nil, nil
6969
}
70-
return nil, fmt.Errorf("config: read providers dir: %w", err)
70+
return nil, nil
7171
}
7272

7373
sort.Slice(entries, func(i, j int) bool {

internal/provider/openaicompat/chatcompletions/provider_test.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,62 @@ func TestNewAndBuildRequest(t *testing.T) {
7171
t.Fatalf("unexpected tools: %+v", payload.Tools)
7272
}
7373

74+
toolSchemaWithTopLevelCombinator := map[string]any{
75+
"type": "object",
76+
"properties": map[string]any{
77+
"action": map[string]any{"type": "string"},
78+
},
79+
"oneOf": []any{
80+
map[string]any{"required": []string{"action"}},
81+
},
82+
}
83+
sanitizedPayload, err := BuildRequest(context.Background(), testCfg("https://api.example.com/v1", "gpt-4.1", "test-key"), providertypes.GenerateRequest{
84+
Messages: []providertypes.Message{
85+
{Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}},
86+
},
87+
Tools: []providertypes.ToolSpec{{
88+
Name: "todo_write",
89+
Description: "write todos",
90+
Schema: toolSchemaWithTopLevelCombinator,
91+
}},
92+
})
93+
if err != nil {
94+
t.Fatalf("BuildRequest() sanitize schema error = %v", err)
95+
}
96+
gotSchema := sanitizedPayload.Tools[0].Function.Parameters
97+
if gotSchema["type"] != "object" {
98+
t.Fatalf("expected sanitized schema type object, got %+v", gotSchema["type"])
99+
}
100+
if _, ok := gotSchema["oneOf"]; !ok {
101+
t.Fatalf("expected top-level oneOf to be preserved, got %+v", gotSchema)
102+
}
103+
if _, ok := toolSchemaWithTopLevelCombinator["oneOf"]; !ok {
104+
t.Fatalf("expected original schema not to be mutated")
105+
}
106+
107+
downgradedPayload, err := BuildRequest(context.Background(), testCfg("https://api.example.com/v1", "gpt-4.1", "test-key"), providertypes.GenerateRequest{
108+
Messages: []providertypes.Message{
109+
{Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}},
110+
},
111+
Tools: []providertypes.ToolSpec{{
112+
Name: "non_object_schema",
113+
Description: "schema root is string",
114+
Schema: map[string]any{
115+
"type": "string",
116+
},
117+
}},
118+
})
119+
if err != nil {
120+
t.Fatalf("BuildRequest() downgrade schema error = %v", err)
121+
}
122+
downgradedSchema := downgradedPayload.Tools[0].Function.Parameters
123+
if downgradedSchema["type"] != "object" {
124+
t.Fatalf("expected downgraded schema type object, got %+v", downgradedSchema["type"])
125+
}
126+
if _, ok := downgradedSchema["x-neocode-schema-downgraded"]; ok {
127+
t.Fatalf("expected no custom downgrade marker in outbound schema, got %+v", downgradedSchema)
128+
}
129+
74130
withSessionAsset, err := BuildRequest(context.Background(), testCfg("https://api.example.com/v1", "gpt-4.1", "test-key"), providertypes.GenerateRequest{
75131
Messages: []providertypes.Message{
76132
{

internal/provider/openaicompat/chatcompletions/request.go

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ func BuildRequest(ctx context.Context, cfg provider.RuntimeConfig, req providert
6767
Function: FunctionDefinition{
6868
Name: spec.Name,
6969
Description: spec.Description,
70-
Parameters: spec.Schema,
70+
Parameters: normalizeToolSchemaForOpenAI(spec.Schema),
7171
},
7272
}
7373
payload.Tools = append(payload.Tools, def)
@@ -77,6 +77,40 @@ func BuildRequest(ctx context.Context, cfg provider.RuntimeConfig, req providert
7777
return payload, nil
7878
}
7979

80+
// normalizeToolSchemaForOpenAI returns a copy of a tool's parameter schema
// whose root has the shape the function-calling payload needs: an object with
// a "properties" map.
//
// The caller's map is never written to; every change is applied to a shallow
// copy of the top level. Keywords other than "type" and "properties"
// (including combinators such as oneOf/anyOf/allOf) are carried over untouched.
//
//   - A nil or empty schema becomes {"type": "object", "properties": {}}.
//   - The root "type" is always emitted as the exact string "object". This also
//     canonicalizes spellings such as "Object" or " object ", which would
//     otherwise be forwarded verbatim.
//   - A missing or non-map "properties" value is replaced by an empty map.
func normalizeToolSchemaForOpenAI(schema map[string]any) map[string]any {
	normalized := cloneSchemaTopLevel(schema)
	if len(normalized) == 0 {
		return map[string]any{
			"type":       "object",
			"properties": map[string]any{},
		}
	}

	// Unconditional on purpose: whatever the caller supplied (wrong type,
	// non-string type, odd casing or padding), the outbound root is "object".
	normalized["type"] = "object"

	if _, ok := normalized["properties"].(map[string]any); !ok {
		normalized["properties"] = map[string]any{}
	}
	return normalized
}

// cloneSchemaTopLevel returns a new map holding the same top-level entries as
// schema. Nested values are shared with the input, not copied. A nil or empty
// input yields an empty, non-nil map.
func cloneSchemaTopLevel(schema map[string]any) map[string]any {
	cloned := make(map[string]any, len(schema))
	for key, value := range schema {
		cloned[key] = value
	}
	return cloned
}
113+
80114
// ToOpenAIMessage 将通用 Message 转换为 OpenAI 协议消息格式。
81115
func ToOpenAIMessage(ctx context.Context, message providertypes.Message, assetReader providertypes.SessionAssetReader) (Message, error) {
82116
msg, _, err := toOpenAIMessageWithBudget(ctx, message, assetReader, maxSessionAssetsTotalBytes)

internal/runtime/events.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,28 @@ type TodoEventPayload struct {
9292
Reason string `json:"reason,omitempty"`
9393
}
9494

95+
// InputNormalizedPayload is the summary carried by an EventInputNormalized
// event once a user submission has been normalized.
type InputNormalizedPayload struct {
	// TextLength is serialized as "text_length".
	TextLength int `json:"text_length"`
	// ImageCount is serialized as "image_count".
	ImageCount int `json:"image_count"`
}
100+
101+
// AssetSavedPayload reports one attachment that was persisted successfully.
// It is the payload of an EventAssetSaved event.
type AssetSavedPayload struct {
	// Index is the attachment's position; serialized as "index".
	Index int `json:"index"`
	// Path is the source path of the attachment; omitted from JSON when empty.
	Path string `json:"path,omitempty"`
	// AssetID identifies the stored asset; serialized as "asset_id".
	AssetID string `json:"asset_id"`
	// MimeType is the asset's MIME type; omitted from JSON when empty.
	MimeType string `json:"mime_type,omitempty"`
	// Size is the asset size; omitted from JSON when zero.
	Size int64 `json:"size,omitempty"`
}
109+
110+
// AssetSaveFailedPayload is the structured description of one attachment that
// could not be persisted. It is the payload of an EventAssetSaveFailed event.
type AssetSaveFailedPayload struct {
	// Index is the failing attachment's position; serialized as "index".
	Index int `json:"index"`
	// Path is the attachment's source path; omitted from JSON when empty.
	Path string `json:"path,omitempty"`
	// Message is the failure text; serialized as "message".
	Message string `json:"message"`
}
116+
95117
const (
96118
// EventUserMessage 表示用户消息已写入会话。
97119
EventUserMessage EventType = "user_message"
@@ -143,6 +165,12 @@ const (
143165
EventTodoConflict EventType = "todo_conflict"
144166
// EventTodoSummaryInjected 表示本轮上下文注入了 Todo 摘要。
145167
EventTodoSummaryInjected EventType = "todo_summary_injected"
168+
// EventInputNormalized 表示用户输入已完成归一化。
169+
EventInputNormalized EventType = "input_normalized"
170+
// EventAssetSaved 表示本轮用户输入附件已完成持久化。
171+
EventAssetSaved EventType = "asset_saved"
172+
// EventAssetSaveFailed 表示本轮用户输入附件持久化失败。
173+
EventAssetSaveFailed EventType = "asset_save_failed"
146174
)
147175

148176
// TokenUsagePayload 承载单轮 token 用量统计。

internal/runtime/input_prepare.go

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
package runtime
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"strings"
8+
"time"
9+
10+
agentsession "neo-code/internal/session"
11+
)
12+
13+
// prepareEventEmitTimeout is the timeout value emitPrepareEvent passes to
// context.WithTimeout when it bounds a prepare-stage event emission.
const prepareEventEmitTimeout = 200 * time.Millisecond
14+
15+
// NewSessionInputPreparer 创建基于 session 子层实现的输入归一化适配器。
16+
func NewSessionInputPreparer(store agentsession.Store, assetStore agentsession.AssetStore) UserInputPreparer {
17+
return sessionInputPreparer{
18+
preparer: agentsession.NewInputPreparer(store, assetStore),
19+
}
20+
}
21+
22+
// PrepareUserInput 负责在运行前执行输入归一化编排,并发出最小可观测事件。
23+
// Submit 作为运行时提交入口,统一串联输入归一化与执行,避免上层手动编排两段调用。
24+
func (s *Service) Submit(ctx context.Context, input PrepareInput) error {
25+
prepared, err := s.PrepareUserInput(ctx, input)
26+
if err != nil {
27+
return err
28+
}
29+
return s.Run(ctx, prepared)
30+
}
31+
32+
func (s *Service) PrepareUserInput(ctx context.Context, input PrepareInput) (UserInput, error) {
33+
if err := ctx.Err(); err != nil {
34+
return UserInput{}, err
35+
}
36+
if s == nil {
37+
return UserInput{}, errors.New("runtime: service is nil")
38+
}
39+
if s.userInputPreparer == nil {
40+
err := errors.New("runtime: user input preparer is not configured")
41+
_ = s.emitPrepareFailure(ctx, input, err)
42+
return UserInput{}, err
43+
}
44+
45+
defaultWorkdir := ""
46+
if s.configManager != nil {
47+
defaultWorkdir = strings.TrimSpace(s.configManager.Get().Workdir)
48+
}
49+
50+
prepared, err := s.userInputPreparer.Prepare(ctx, input, defaultWorkdir)
51+
if err != nil {
52+
_ = s.emitPrepareFailure(ctx, input, err)
53+
return UserInput{}, err
54+
}
55+
56+
runID := strings.TrimSpace(input.RunID)
57+
_ = s.emitPrepareEvent(ctx, EventInputNormalized, runID, prepared.UserInput.SessionID, InputNormalizedPayload{
58+
TextLength: len([]rune(strings.TrimSpace(input.Text))),
59+
ImageCount: len(input.Images),
60+
})
61+
for index, asset := range prepared.SavedAssets {
62+
path := ""
63+
if index >= 0 && index < len(input.Images) {
64+
path = strings.TrimSpace(input.Images[index].Path)
65+
}
66+
_ = s.emitPrepareEvent(ctx, EventAssetSaved, runID, prepared.UserInput.SessionID, AssetSavedPayload{
67+
Index: index,
68+
Path: path,
69+
AssetID: asset.ID,
70+
MimeType: asset.MimeType,
71+
Size: asset.Size,
72+
})
73+
}
74+
75+
return prepared.UserInput, nil
76+
}
77+
78+
// emitPrepareFailure 统一发送输入归一化阶段的失败事件,避免前置副作用变成黑箱。
79+
func (s *Service) emitPrepareFailure(ctx context.Context, input PrepareInput, err error) error {
80+
if s == nil {
81+
return nil
82+
}
83+
84+
runID := strings.TrimSpace(input.RunID)
85+
sessionID := strings.TrimSpace(input.SessionID)
86+
87+
var saveErr *agentsession.AssetSaveError
88+
if errors.As(err, &saveErr) {
89+
if session := strings.TrimSpace(saveErr.SessionID); session != "" {
90+
sessionID = session
91+
}
92+
return s.emitPrepareEvent(ctx, EventAssetSaveFailed, runID, sessionID, AssetSaveFailedPayload{
93+
Index: saveErr.Index,
94+
Path: strings.TrimSpace(saveErr.Path),
95+
Message: strings.TrimSpace(saveErr.Error()),
96+
})
97+
}
98+
return s.emitPrepareEvent(ctx, EventError, runID, sessionID, strings.TrimSpace(err.Error()))
99+
}
100+
101+
// emitPrepareEvent 在输入归一化阶段使用限时上下文发事件,避免通道拥塞导致提交链路卡死。
102+
func (s *Service) emitPrepareEvent(ctx context.Context, kind EventType, runID string, sessionID string, payload any) error {
103+
emitCtx := ctx
104+
cancel := func() {}
105+
if _, hasDeadline := emitCtx.Deadline(); !hasDeadline {
106+
emitCtx, cancel = context.WithTimeout(emitCtx, prepareEventEmitTimeout)
107+
}
108+
defer cancel()
109+
110+
if err := s.emit(emitCtx, kind, runID, sessionID, payload); err != nil {
111+
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
112+
return nil
113+
}
114+
return err
115+
}
116+
return nil
117+
}
118+
119+
type sessionInputPreparer struct {
120+
preparer *agentsession.InputPreparer
121+
}
122+
123+
// Prepare 将 runtime 输入 DTO 映射到 session 子层并返回标准 UserInput 结果。
124+
func (p sessionInputPreparer) Prepare(
125+
ctx context.Context,
126+
input PrepareInput,
127+
defaultWorkdir string,
128+
) (PreparedInputResult, error) {
129+
if p.preparer == nil {
130+
return PreparedInputResult{}, errors.New("runtime: session input preparer is nil")
131+
}
132+
133+
sessionImages := make([]agentsession.PrepareImageInput, 0, len(input.Images))
134+
for _, image := range input.Images {
135+
sessionImages = append(sessionImages, agentsession.PrepareImageInput{
136+
Path: strings.TrimSpace(image.Path),
137+
MimeType: strings.TrimSpace(image.MimeType),
138+
})
139+
}
140+
141+
prepared, err := p.preparer.Prepare(ctx, agentsession.PrepareInput{
142+
SessionID: strings.TrimSpace(input.SessionID),
143+
Text: input.Text,
144+
Images: sessionImages,
145+
DefaultWorkdir: strings.TrimSpace(defaultWorkdir),
146+
RequestedWorkdir: strings.TrimSpace(input.Workdir),
147+
})
148+
if err != nil {
149+
return PreparedInputResult{}, err
150+
}
151+
152+
if len(prepared.Parts) == 0 {
153+
return PreparedInputResult{}, fmt.Errorf("runtime: prepared parts is empty")
154+
}
155+
156+
return PreparedInputResult{
157+
UserInput: UserInput{
158+
SessionID: strings.TrimSpace(prepared.SessionID),
159+
RunID: strings.TrimSpace(input.RunID),
160+
Parts: prepared.Parts,
161+
Workdir: strings.TrimSpace(prepared.Workdir),
162+
},
163+
SavedAssets: append([]agentsession.AssetMeta(nil), prepared.SavedAssets...),
164+
}, nil
165+
}

0 commit comments

Comments
 (0)