Skip to content

Commit ae49927

Browse files
committed
feat(runtime): 会话加载时自动修复未闭合 tool_calls 尾部并采用安全裁剪
新增 transcript 工具函数:RepairIncompleteToolCallTail 检测并截断未闭合的 assistant tool_calls 尾部,TrimMessagesToLimitPreservingToolSpans / TrimPrefixCountPreservingToolSpans 保证裁剪不会从 tool span 中间切断。 runtime 层 LoadSession 统一走修复路径,所有内部调用方从直接调用 sessionStore.LoadSession 改为走 Service.LoadSession,确保任何加载入口 都会修复残缺消息链。SQLite 持久层的消息裁剪同步改为安全版本。
1 parent ee99d33 commit ae49927

15 files changed

Lines changed: 435 additions & 24 deletions

internal/runtime/checkpoint_restore.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ func (s *Service) UndoRestoreCheckpoint(ctx context.Context, sessionID string) (
218218
// guardWritten=false 时若 fallbackRef 非空,则用它作为 CodeCheckpointRef 以保证 undo 可走代码恢复路径。
219219
// fallbackRef 应为完整的 "peredit:<id>" 格式引用。
220220
func (s *Service) createGuardCheckpoint(ctx context.Context, sessionID, runID, guardID string, guardWritten bool, fallbackRef string) (agentsession.CheckpointRecord, error) {
221-
session, err := s.sessionStore.LoadSession(ctx, sessionID)
221+
session, err := s.LoadSession(ctx, sessionID)
222222
if err != nil {
223223
return agentsession.CheckpointRecord{}, fmt.Errorf("checkpoint: load session for guard: %w", err)
224224
}

internal/runtime/compact.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ func (s *Service) Compact(ctx context.Context, input CompactInput) (CompactResul
8282
if err != nil {
8383
return CompactResult{}, err
8484
}
85-
session, err := s.sessionStore.LoadSession(ctx, input.SessionID)
85+
session, err := s.LoadSession(ctx, input.SessionID)
8686
if err != nil {
8787
return CompactResult{}, err
8888
}

internal/runtime/plan_approval.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ func (s *Service) ApproveCurrentPlan(ctx context.Context, input ApproveCurrentPl
1919
releaseLock := s.bindSessionLock(sessionID)
2020
defer releaseLock()
2121

22-
session, err := s.sessionStore.LoadSession(ctx, sessionID)
22+
session, err := s.LoadSession(ctx, sessionID)
2323
if err != nil {
2424
return err
2525
}

internal/runtime/runtime.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ import (
1313
"neo-code/internal/config"
1414
agentcontext "neo-code/internal/context"
1515
contextcompact "neo-code/internal/context/compact"
16-
"neo-code/internal/repository"
1716
"neo-code/internal/provider"
1817
"neo-code/internal/provider/builtin"
1918
providertypes "neo-code/internal/provider/types"
19+
"neo-code/internal/repository"
2020
"neo-code/internal/runtime/approval"
2121
runtimehooks "neo-code/internal/runtime/hooks"
2222
"neo-code/internal/security"
@@ -422,6 +422,9 @@ func (s *Service) LoadSession(ctx context.Context, id string) (agentsession.Sess
422422
if err != nil {
423423
return agentsession.Session{}, err
424424
}
425+
if err := s.repairSessionTranscriptIfNeeded(ctx, &session); err != nil {
426+
return agentsession.Session{}, err
427+
}
425428
return session, nil
426429
}
427430

@@ -443,7 +446,7 @@ func (s *Service) CreateSession(ctx context.Context, id string) (agentsession.Se
443446
return agentsession.Session{}, err
444447
}
445448

446-
existing, err := s.sessionStore.LoadSession(ctx, sessionID)
449+
existing, err := s.LoadSession(ctx, sessionID)
447450
if err == nil {
448451
return existing, nil
449452
}
@@ -459,7 +462,7 @@ func (s *Service) CreateSession(ctx context.Context, id string) (agentsession.Se
459462
return created, nil
460463
}
461464
if isRuntimeSessionAlreadyExistsError(createErr) {
462-
return s.sessionStore.LoadSession(ctx, sessionID)
465+
return s.LoadSession(ctx, sessionID)
463466
}
464467
return agentsession.Session{}, createErr
465468
}

internal/runtime/runtime_snapshot.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ func (s *Service) GetRuntimeSnapshot(ctx context.Context, sessionID string) (Run
197197
return cached, nil
198198
}
199199

200-
session, err := s.sessionStore.LoadSession(ctx, normalizedSessionID)
200+
session, err := s.LoadSession(ctx, normalizedSessionID)
201201
if err != nil {
202202
return RuntimeSnapshot{}, err
203203
}

internal/runtime/runtime_test.go

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ import (
1515
"neo-code/internal/config"
1616
agentcontext "neo-code/internal/context"
1717
contextcompact "neo-code/internal/context/compact"
18-
"neo-code/internal/repository"
1918
"neo-code/internal/provider"
2019
providertypes "neo-code/internal/provider/types"
20+
"neo-code/internal/repository"
2121
approvalflow "neo-code/internal/runtime/approval"
2222
"neo-code/internal/runtime/controlplane"
2323
"neo-code/internal/runtime/streaming"
@@ -3594,6 +3594,108 @@ func TestServiceListSessionsSkipsPromotionWhenDerivedTitleInvalid(t *testing.T)
35943594
}
35953595
}
35963596

3597+
func TestServiceLoadSessionRepairsIncompleteToolCallTail(t *testing.T) {
3598+
manager := newRuntimeConfigManager(t)
3599+
store := newMemoryStore()
3600+
service := NewWithFactory(manager, tools.NewRegistry(), store, nil, nil)
3601+
3602+
session := agentsession.New("Repair Me")
3603+
session.Messages = []providertypes.Message{
3604+
{Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("before")}},
3605+
{
3606+
Role: providertypes.RoleAssistant,
3607+
ToolCalls: []providertypes.ToolCall{
3608+
{ID: "call-1", Name: "filesystem_read_file", Arguments: `{"path":"README.md"}`},
3609+
{ID: "call-2", Name: "bash", Arguments: `{"command":"echo hi"}`},
3610+
},
3611+
},
3612+
{
3613+
Role: providertypes.RoleTool,
3614+
ToolCallID: "call-1",
3615+
Parts: []providertypes.ContentPart{providertypes.NewTextPart("README")},
3616+
},
3617+
}
3618+
store.sessions[session.ID] = cloneSession(session)
3619+
3620+
loaded, err := service.LoadSession(context.Background(), session.ID)
3621+
if err != nil {
3622+
t.Fatalf("LoadSession() error = %v", err)
3623+
}
3624+
if len(loaded.Messages) != 1 {
3625+
t.Fatalf("len(loaded.Messages) = %d, want 1", len(loaded.Messages))
3626+
}
3627+
if got := renderPartsForTest(loaded.Messages[0].Parts); got != "before" {
3628+
t.Fatalf("loaded preserved message = %q, want %q", got, "before")
3629+
}
3630+
3631+
persisted, err := store.LoadSession(context.Background(), session.ID)
3632+
if err != nil {
3633+
t.Fatalf("store.LoadSession() error = %v", err)
3634+
}
3635+
if len(persisted.Messages) != 1 {
3636+
t.Fatalf("len(persisted.Messages) = %d, want 1", len(persisted.Messages))
3637+
}
3638+
}
3639+
3640+
func TestServiceRunRepairsIncompleteToolCallTailBeforeBuildingContext(t *testing.T) {
3641+
manager := newRuntimeConfigManager(t)
3642+
store := newMemoryStore()
3643+
builder := &stubContextBuilder{}
3644+
scripted := &scriptedProvider{
3645+
responses: []scriptedResponse{
3646+
{
3647+
Message: providertypes.Message{
3648+
Role: providertypes.RoleAssistant,
3649+
Parts: []providertypes.ContentPart{providertypes.NewTextPart("done")},
3650+
},
3651+
FinishReason: "stop",
3652+
},
3653+
},
3654+
}
3655+
service := NewWithFactory(manager, tools.NewRegistry(), store, &scriptedProviderFactory{provider: scripted}, builder)
3656+
3657+
session := agentsession.New("Repair Before Run")
3658+
session.Messages = []providertypes.Message{
3659+
{Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("before")}},
3660+
{
3661+
Role: providertypes.RoleAssistant,
3662+
ToolCalls: []providertypes.ToolCall{
3663+
{ID: "call-1", Name: "filesystem_read_file", Arguments: `{"path":"README.md"}`},
3664+
},
3665+
},
3666+
}
3667+
store.sessions[session.ID] = cloneSession(session)
3668+
3669+
if err := service.Run(context.Background(), UserInput{
3670+
SessionID: session.ID,
3671+
RunID: "run-repair-incomplete-tool-tail",
3672+
Parts: []providertypes.ContentPart{providertypes.NewTextPart("continue")},
3673+
}); err != nil {
3674+
t.Fatalf("Run() error = %v", err)
3675+
}
3676+
3677+
if len(builder.lastInput.Messages) != 2 {
3678+
t.Fatalf("len(builder.lastInput.Messages) = %d, want 2", len(builder.lastInput.Messages))
3679+
}
3680+
if builder.lastInput.Messages[0].Role != providertypes.RoleUser || renderPartsForTest(builder.lastInput.Messages[0].Parts) != "before" {
3681+
t.Fatalf("unexpected repaired history in builder input: %+v", builder.lastInput.Messages)
3682+
}
3683+
if builder.lastInput.Messages[1].Role != providertypes.RoleUser || renderPartsForTest(builder.lastInput.Messages[1].Parts) != "continue" {
3684+
t.Fatalf("expected latest user input in builder messages, got %+v", builder.lastInput.Messages)
3685+
}
3686+
3687+
persisted, err := store.LoadSession(context.Background(), session.ID)
3688+
if err != nil {
3689+
t.Fatalf("store.LoadSession() error = %v", err)
3690+
}
3691+
if len(persisted.Messages) < 2 {
3692+
t.Fatalf("expected repaired transcript plus new turn, got %+v", persisted.Messages)
3693+
}
3694+
if len(persisted.Messages[0].ToolCalls) != 0 {
3695+
t.Fatalf("expected repaired transcript to drop dangling tool_calls, got %+v", persisted.Messages[0])
3696+
}
3697+
}
3698+
35973699
func TestRuntimeSessionTitlePromotionHelpers(t *testing.T) {
35983700
t.Parallel()
35993701

internal/runtime/session_scheduler.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ func (s *Service) loadOrCreateSession(
3434
return session, nil
3535
}
3636

37-
session, err := s.sessionStore.LoadSession(ctx, sessionID)
37+
session, err := s.LoadSession(ctx, sessionID)
3838
if err != nil {
3939
return agentsession.Session{}, err
4040
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package runtime
2+
3+
import (
4+
"context"
5+
"time"
6+
7+
agentsession "neo-code/internal/session"
8+
)
9+
10+
// repairSessionTranscriptIfNeeded 检查会话 transcript 是否存在未闭合的 tool_calls 尾部,
11+
// 若存在则截断残缺尾巴并原子回写,避免后续继续对话时向 provider 发送非法消息链。
12+
func (s *Service) repairSessionTranscriptIfNeeded(ctx context.Context, session *agentsession.Session) error {
13+
if s == nil || session == nil {
14+
return nil
15+
}
16+
17+
repairedMessages, repaired := agentsession.RepairIncompleteToolCallTail(session.Messages)
18+
if !repaired {
19+
return nil
20+
}
21+
22+
session.Messages = repairedMessages
23+
session.UpdatedAt = time.Now()
24+
return s.sessionStore.ReplaceTranscript(ctx, replaceTranscriptInputFromSession(*session))
25+
}

internal/runtime/skills.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ func (s *Service) ListSessionSkills(ctx context.Context, sessionID string) ([]Se
9494
return nil, errors.New("runtime: session id is empty")
9595
}
9696

97-
session, err := s.sessionStore.LoadSession(ctx, sessionID)
97+
session, err := s.LoadSession(ctx, sessionID)
9898
if err != nil {
9999
return nil, err
100100
}
@@ -142,7 +142,7 @@ func (s *Service) ListAvailableSkills(ctx context.Context, sessionID string) ([]
142142
workspace := ""
143143
activeSet := map[string]struct{}{}
144144
if normalizedSessionID != "" {
145-
session, err := s.sessionStore.LoadSession(ctx, normalizedSessionID)
145+
session, err := s.LoadSession(ctx, normalizedSessionID)
146146
if err != nil {
147147
return nil, err
148148
}
@@ -321,7 +321,7 @@ func (s *Service) mutateSessionSkills(
321321
releaseLockRef()
322322
}()
323323

324-
session, err := s.sessionStore.LoadSession(ctx, sessionID)
324+
session, err := s.LoadSession(ctx, sessionID)
325325
if err != nil {
326326
return agentsession.Session{}, false, err
327327
}

internal/runtime/system_tool.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ func (s *Service) ExecuteSystemTool(ctx context.Context, input SystemToolInput)
4545
sessionMu, releaseLockRef := s.acquireSessionLock(sessionID)
4646
sessionMu.RLock()
4747

48-
session, err := s.sessionStore.LoadSession(ctx, sessionID)
48+
session, err := s.LoadSession(ctx, sessionID)
4949
if err != nil {
5050
sessionMu.RUnlock()
5151
releaseLockRef()

0 commit comments

Comments
 (0)