Skip to content

Commit 93a0713

Browse files
fix(mcp): require recovery token for ambiguous project choice
1 parent efa6e79 commit 93a0713

3 files changed

Lines changed: 504 additions & 21 deletions

File tree

internal/mcp/activity.go

Lines changed: 78 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
package mcp
22

33
import (
4+
"crypto/rand"
5+
"encoding/hex"
46
"fmt"
7+
"path/filepath"
8+
"slices"
59
"sync"
610
"time"
711
)
812

13+
const ambiguousProjectRecoveryTTL = 5 * time.Minute
14+
915
// SessionActivity tracks tool call activity for save reminders and activity scores.
1016
type SessionActivity struct {
1117
mu sync.Mutex
@@ -15,18 +21,26 @@ type SessionActivity struct {
1521
}
1622

1723
type sessionState struct {
18-
lastSaveAt time.Time
19-
toolCallCount int
20-
saveCount int
21-
startedAt time.Time
22-
currentPrompt *promptContext
24+
lastSaveAt time.Time
25+
toolCallCount int
26+
saveCount int
27+
startedAt time.Time
28+
currentPrompt *promptContext
29+
recoveryTokens map[string]*ambiguousProjectRecovery
2330
}
2431

2532
type promptContext struct {
2633
project string
2734
content string
2835
}
2936

37+
type ambiguousProjectRecovery struct {
38+
availableProjects []string
39+
contextPath string
40+
expiresAt time.Time
41+
selectedProject string
42+
}
43+
3044
// NewSessionActivity creates a new activity tracker with the given nudge threshold.
3145
func NewSessionActivity(nudgeAfter time.Duration) *SessionActivity {
3246
return &SessionActivity{
@@ -36,6 +50,14 @@ func NewSessionActivity(nudgeAfter time.Duration) *SessionActivity {
3650
}
3751
}
3852

53+
func generateRecoveryToken() string {
54+
var b [16]byte
55+
if _, err := rand.Read(b[:]); err != nil {
56+
return fmt.Sprintf("fallback-%d", time.Now().UnixNano())
57+
}
58+
return hex.EncodeToString(b[:])
59+
}
60+
3961
func (a *SessionActivity) getOrCreate(sessionID string) *sessionState {
4062
s, ok := a.sessions[sessionID]
4163
if !ok {
@@ -60,6 +82,57 @@ func (a *SessionActivity) ClearSession(sessionID string) {
6082
delete(a.sessions, sessionID)
6183
}
6284

85+
func (a *SessionActivity) IssueAmbiguousProjectRecoveryToken(sessionID string, availableProjects []string, contextPath string) string {
86+
if a == nil {
87+
return ""
88+
}
89+
a.mu.Lock()
90+
defer a.mu.Unlock()
91+
s := a.getOrCreate(sessionID)
92+
if s.recoveryTokens == nil {
93+
s.recoveryTokens = make(map[string]*ambiguousProjectRecovery)
94+
}
95+
token := generateRecoveryToken()
96+
projects := append([]string(nil), availableProjects...)
97+
slices.Sort(projects)
98+
s.recoveryTokens[token] = &ambiguousProjectRecovery{
99+
availableProjects: projects,
100+
contextPath: filepath.Clean(contextPath),
101+
expiresAt: a.now().Add(ambiguousProjectRecoveryTTL),
102+
}
103+
return token
104+
}
105+
106+
func (a *SessionActivity) ValidateAmbiguousProjectRecoveryToken(sessionID, token, selectedProject string, availableProjects []string, contextPath string) bool {
107+
if a == nil || token == "" || selectedProject == "" {
108+
return false
109+
}
110+
a.mu.Lock()
111+
defer a.mu.Unlock()
112+
s, ok := a.sessions[sessionID]
113+
if !ok || s.recoveryTokens == nil {
114+
return false
115+
}
116+
recovery, ok := s.recoveryTokens[token]
117+
if !ok {
118+
return false
119+
}
120+
if !recovery.expiresAt.IsZero() && !a.now().Before(recovery.expiresAt) {
121+
delete(s.recoveryTokens, token)
122+
return false
123+
}
124+
projects := append([]string(nil), availableProjects...)
125+
slices.Sort(projects)
126+
if !slices.Equal(recovery.availableProjects, projects) || recovery.contextPath != filepath.Clean(contextPath) {
127+
return false
128+
}
129+
if recovery.selectedProject == "" {
130+
recovery.selectedProject = selectedProject
131+
return true
132+
}
133+
return recovery.selectedProject == selectedProject
134+
}
135+
63136
// RecordSave increments the save counter and updates lastSaveAt.
64137
func (a *SessionActivity) RecordSave(sessionID string) {
65138
a.mu.Lock()

internal/mcp/mcp.go

Lines changed: 128 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,9 @@ Examples:
339339
mcp.WithString("project_choice_reason",
340340
mcp.Description("Must be user_selected_after_ambiguous_project, and only after the user explicitly chose one of available_projects from an ambiguous_project error."),
341341
),
342+
mcp.WithString("recovery_token",
343+
mcp.Description("Short-lived token returned by an ambiguous_project error. Required with project_choice_reason=user_selected_after_ambiguous_project."),
344+
),
342345
mcp.WithBoolean("capture_prompt",
343346
mcp.Description("Automatically capture the current user prompt when available (default: true). Set false for SDD artifacts or automated saves."),
344347
),
@@ -453,6 +456,9 @@ Examples:
453456
mcp.WithString("project_choice_reason",
454457
mcp.Description("Must be user_selected_after_ambiguous_project, and only after the user explicitly chose one of available_projects from an ambiguous_project error."),
455458
),
459+
mcp.WithString("recovery_token",
460+
mcp.Description("Short-lived token returned by an ambiguous_project error. Required with project_choice_reason=user_selected_after_ambiguous_project."),
461+
),
456462
),
457463
queuedWriteHandler(writeQueue, handleSavePrompt(s, cfg, activity)),
458464
)
@@ -1022,13 +1028,24 @@ func handleSave(s *store.Store, cfg MCPConfig, activity *SessionActivity) server
10221028
projectChoice, _ := req.GetArguments()["project"].(string)
10231029
_, explicitProjectProvided := req.GetArguments()["project"]
10241030
projectChoiceReason, _ := req.GetArguments()["project_choice_reason"].(string)
1031+
recoveryToken, _ := req.GetArguments()["recovery_token"].(string)
10251032
capturePrompt := boolArg(req, "capture_prompt", true)
1033+
recoverySessionID := sessionID
1034+
if strings.TrimSpace(recoverySessionID) == "" {
1035+
recoverySessionID = defaultSessionID("")
1036+
}
1037+
validateRecoveryToken := func(res projectpkg.DetectionResult, choice string) (bool, bool) {
1038+
if strings.TrimSpace(recoveryToken) == "" {
1039+
return false, false
1040+
}
1041+
return true, activity.ValidateAmbiguousProjectRecoveryToken(recoverySessionID, recoveryToken, strings.TrimSpace(choice), res.AvailableProjects, res.Path)
1042+
}
10261043

10271044
// Resolve write project using the full MCP precedence: explicit request,
10281045
// existing session association, repo config/directory detection, then cwd fallback.
1029-
detRes, err := resolveSaveWriteProject(s, projectChoice, explicitProjectProvided, projectChoiceReason, sessionID)
1046+
detRes, err := resolveSaveWriteProject(s, projectChoice, explicitProjectProvided, projectChoiceReason, sessionID, validateRecoveryToken)
10301047
if err != nil {
1031-
return writeProjectErrorResult(detRes, err), nil
1048+
return writeProjectErrorResult(activity, recoverySessionID, detRes, err), nil
10321049
}
10331050
project := detRes.Project
10341051

@@ -1271,10 +1288,21 @@ func handleSavePrompt(s *store.Store, cfg MCPConfig, activity *SessionActivity)
12711288
sessionID, _ := req.GetArguments()["session_id"].(string)
12721289
projectChoice, _ := req.GetArguments()["project"].(string)
12731290
projectChoiceReason, _ := req.GetArguments()["project_choice_reason"].(string)
1291+
recoveryToken, _ := req.GetArguments()["recovery_token"].(string)
1292+
recoverySessionID := sessionID
1293+
if strings.TrimSpace(recoverySessionID) == "" {
1294+
recoverySessionID = defaultSessionID("")
1295+
}
1296+
validateRecoveryToken := func(res projectpkg.DetectionResult, choice string) (bool, bool) {
1297+
if strings.TrimSpace(recoveryToken) == "" {
1298+
return false, false
1299+
}
1300+
return true, activity.ValidateAmbiguousProjectRecoveryToken(recoverySessionID, recoveryToken, strings.TrimSpace(choice), res.AvailableProjects, res.Path)
1301+
}
12741302

1275-
detRes, err := resolveWriteProjectWithChoice(projectChoice, projectChoiceReason)
1303+
detRes, err := resolveWriteProjectWithChoice(projectChoice, projectChoiceReason, validateRecoveryToken)
12761304
if err != nil {
1277-
return writeProjectErrorResult(detRes, err), nil
1305+
return writeProjectErrorResult(activity, recoverySessionID, detRes, err), nil
12781306
}
12791307
project, _ := store.NormalizeProject(detRes.Project)
12801308

@@ -1556,7 +1584,7 @@ func handleSessionSummary(s *store.Store, cfg MCPConfig, activity *SessionActivi
15561584
// Auto-detect project from cwd; fail fast on ambiguous (REQ-308, REQ-309)
15571585
detRes, err := resolveWriteProject()
15581586
if err != nil {
1559-
return writeProjectErrorResult(detRes, err), nil
1587+
return writeProjectErrorResult(nil, "", detRes, err), nil
15601588
}
15611589
project, _ := store.NormalizeProject(detRes.Project)
15621590

@@ -1596,7 +1624,7 @@ func handleSessionStart(s *store.Store, cfg MCPConfig, activity *SessionActivity
15961624

15971625
detRes, err := resolveSessionStartProject(resolvedDirectory)
15981626
if err != nil {
1599-
return writeProjectErrorResult(detRes, err), nil
1627+
return writeProjectErrorResult(nil, "", detRes, err), nil
16001628
}
16011629
project, _ := store.NormalizeProject(detRes.Project)
16021630

@@ -1637,7 +1665,7 @@ func handleSessionEnd(s *store.Store, cfg MCPConfig, activity *SessionActivity)
16371665
detRes, err := resolveWriteProject()
16381666
if err != nil {
16391667
if errors.Is(err, projectpkg.ErrInvalidConfig) {
1640-
return writeProjectErrorResult(detRes, err), nil
1668+
return writeProjectErrorResult(nil, "", detRes, err), nil
16411669
}
16421670
// For session end, still complete the operation even if project resolution fails.
16431671
// Use basename fallback.
@@ -1670,7 +1698,7 @@ func handleCapturePassive(s *store.Store, cfg MCPConfig, activity *SessionActivi
16701698

16711699
detRes, err := resolveWriteProject()
16721700
if err != nil {
1673-
return writeProjectErrorResult(detRes, err), nil
1701+
return writeProjectErrorResult(nil, "", detRes, err), nil
16741702
}
16751703
project, _ := store.NormalizeProject(detRes.Project)
16761704

@@ -1914,6 +1942,24 @@ func (e *invalidProjectChoiceError) Error() string {
19141942
return "invalid project choice: " + e.Name
19151943
}
19161944

1945+
type missingRecoveryTokenError struct {
1946+
Name string
1947+
AvailableProjects []string
1948+
}
1949+
1950+
func (e *missingRecoveryTokenError) Error() string {
1951+
return "missing ambiguous project recovery token for project choice: " + e.Name
1952+
}
1953+
1954+
type invalidRecoveryTokenError struct {
1955+
Name string
1956+
AvailableProjects []string
1957+
}
1958+
1959+
func (e *invalidRecoveryTokenError) Error() string {
1960+
return "invalid ambiguous project recovery token for project choice: " + e.Name
1961+
}
1962+
19171963
type invalidExplicitProjectError struct {
19181964
Name string
19191965
Reason string
@@ -1968,9 +2014,11 @@ func resolveWriteProject() (projectpkg.DetectionResult, error) {
19682014
return res, nil
19692015
}
19702016

2017+
type ambiguousRecoveryTokenValidator func(projectpkg.DetectionResult, string) (provided bool, valid bool)
2018+
19712019
// resolveWriteProjectWithChoice preserves normal write resolution authority and
19722020
// only uses an explicit project choice as a recovery path from ErrAmbiguousProject.
1973-
func resolveWriteProjectWithChoice(projectChoice, reason string) (projectpkg.DetectionResult, error) {
2021+
func resolveWriteProjectWithChoice(projectChoice, reason string, validateToken ambiguousRecoveryTokenValidator) (projectpkg.DetectionResult, error) {
19742022
res, err := resolveWriteProject()
19752023
if err == nil {
19762024
// Non-ambiguous config/git/autodetect remains authoritative. Ignore any
@@ -1999,6 +2047,22 @@ func resolveWriteProjectWithChoice(projectChoice, reason string) (projectpkg.Det
19992047
CollidingProjects: colliding,
20002048
}
20012049
}
2050+
provided, valid := false, false
2051+
if validateToken != nil {
2052+
provided, valid = validateToken(res, choice)
2053+
}
2054+
if !provided {
2055+
return res, &missingRecoveryTokenError{
2056+
Name: choice,
2057+
AvailableProjects: res.AvailableProjects,
2058+
}
2059+
}
2060+
if !valid {
2061+
return res, &invalidRecoveryTokenError{
2062+
Name: choice,
2063+
AvailableProjects: res.AvailableProjects,
2064+
}
2065+
}
20022066

20032067
res.Project = choice
20042068
res.Source = projectpkg.SourceUserSelectedAfterAmbiguousProject
@@ -2007,7 +2071,7 @@ func resolveWriteProjectWithChoice(projectChoice, reason string) (projectpkg.Det
20072071
return res, nil
20082072
}
20092073

2010-
func resolveSaveWriteProject(s *store.Store, projectChoice string, explicitProjectProvided bool, reason, sessionID string) (projectpkg.DetectionResult, error) {
2074+
func resolveSaveWriteProject(s *store.Store, projectChoice string, explicitProjectProvided bool, reason, sessionID string, validateToken ambiguousRecoveryTokenValidator) (projectpkg.DetectionResult, error) {
20112075
trimmedSessionID := strings.TrimSpace(sessionID)
20122076
trimmedProjectChoice := strings.TrimSpace(projectChoice)
20132077
trimmedReason := strings.TrimSpace(reason)
@@ -2096,7 +2160,7 @@ func resolveSaveWriteProject(s *store.Store, projectChoice string, explicitProje
20962160
}
20972161
if errors.Is(cwdErr, projectpkg.ErrAmbiguousProject) {
20982162
if trimmedReason == projectpkg.SourceUserSelectedAfterAmbiguousProject {
2099-
return resolveWriteProjectWithChoice(projectChoice, reason)
2163+
return resolveWriteProjectWithChoice(projectChoice, reason, validateToken)
21002164
}
21012165
return cwdRes, cwdErr
21022166
}
@@ -2124,7 +2188,7 @@ func resolveSaveWriteProject(s *store.Store, projectChoice string, explicitProje
21242188
}
21252189

21262190
if trimmedReason == projectpkg.SourceUserSelectedAfterAmbiguousProject && trimmedProjectChoice != "" {
2127-
res, err := resolveWriteProjectWithChoice(projectChoice, reason)
2191+
res, err := resolveWriteProjectWithChoice(projectChoice, reason, validateToken)
21282192
if err != nil {
21292193
return res, err
21302194
}
@@ -2379,7 +2443,7 @@ func respondWithProject(res projectpkg.DetectionResult, text string, extra map[s
23792443
return mcp.NewToolResultText(string(out))
23802444
}
23812445

2382-
func writeProjectErrorResult(res projectpkg.DetectionResult, err error) *mcp.CallToolResult {
2446+
func writeProjectErrorResult(activity *SessionActivity, sessionID string, res projectpkg.DetectionResult, err error) *mcp.CallToolResult {
23832447
code := "ambiguous_project"
23842448
if errors.Is(err, projectpkg.ErrInvalidConfig) {
23852449
code = "invalid_project_config"
@@ -2397,6 +2461,20 @@ func writeProjectErrorResult(res projectpkg.DetectionResult, err error) *mcp.Cal
23972461
choiceErr.AvailableProjects,
23982462
)
23992463
}
2464+
var missingTokenErr *missingRecoveryTokenError
2465+
if errors.As(err, &missingTokenErr) {
2466+
return errorWithMeta("missing_recovery_token",
2467+
fmt.Sprintf("project_choice_reason=user_selected_after_ambiguous_project for %q requires the recovery_token from the ambiguous_project error", missingTokenErr.Name),
2468+
missingTokenErr.AvailableProjects,
2469+
)
2470+
}
2471+
var invalidTokenErr *invalidRecoveryTokenError
2472+
if errors.As(err, &invalidTokenErr) {
2473+
return errorWithMeta("invalid_recovery_token",
2474+
fmt.Sprintf("recovery_token is invalid, stale, or not valid for selected project %q", invalidTokenErr.Name),
2475+
invalidTokenErr.AvailableProjects,
2476+
)
2477+
}
24002478
var explicitErr *invalidExplicitProjectError
24012479
if errors.As(err, &explicitErr) {
24022480
return errorWithMeta("invalid_project",
@@ -2435,7 +2513,39 @@ func writeProjectErrorResult(res projectpkg.DetectionResult, err error) *mcp.Cal
24352513
res.AvailableProjects,
24362514
)
24372515
}
2438-
return errorWithMeta(code, fmt.Sprintf("Cannot determine project: %s", err), res.AvailableProjects)
2516+
result := errorWithMeta(code, fmt.Sprintf("Cannot determine project: %s", err), res.AvailableProjects)
2517+
if code == "ambiguous_project" && activity != nil {
2518+
if strings.TrimSpace(sessionID) == "" {
2519+
sessionID = defaultSessionID("")
2520+
}
2521+
addErrorMetadata(result, map[string]any{
2522+
"recovery_token": activity.IssueAmbiguousProjectRecoveryToken(sessionID, res.AvailableProjects, res.Path),
2523+
"token_ttl_seconds": int(ambiguousProjectRecoveryTTL.Seconds()),
2524+
})
2525+
}
2526+
return result
2527+
}
2528+
2529+
func addErrorMetadata(result *mcp.CallToolResult, metadata map[string]any) {
2530+
if result == nil || len(result.Content) == 0 || len(metadata) == 0 {
2531+
return
2532+
}
2533+
text, ok := mcp.AsTextContent(result.Content[0])
2534+
if !ok {
2535+
return
2536+
}
2537+
var envelope map[string]any
2538+
if err := json.Unmarshal([]byte(text.Text), &envelope); err != nil {
2539+
return
2540+
}
2541+
for k, v := range metadata {
2542+
envelope[k] = v
2543+
}
2544+
out, err := jsonMarshal(envelope)
2545+
if err != nil {
2546+
return
2547+
}
2548+
result.Content[0] = mcp.NewTextContent(string(out))
24392549
}
24402550

24412551
// errorWithMeta returns a structured tool error result with error_code,
@@ -2451,6 +2561,10 @@ func errorWithMeta(code, msg string, availableProjects []string) *mcp.CallToolRe
24512561
envelope["hint"] = "Ask the user to choose one of available_projects, then retry mem_save or mem_save_prompt with project and project_choice_reason=user_selected_after_ambiguous_project; alternatively cd into the target repo or add repo .engram/config.json."
24522562
case "invalid_project_choice":
24532563
envelope["hint"] = "Use exactly one of available_projects after asking the user, or cd into the target repo, or add repo .engram/config.json."
2564+
case "missing_recovery_token":
2565+
envelope["hint"] = "Retry with the recovery_token returned by the ambiguous_project error after the user selects one available_projects value."
2566+
case "invalid_recovery_token":
2567+
envelope["hint"] = "Request a fresh ambiguous_project recovery_token and retry with the same session, cwd context, and selected available_projects value before it expires."
24542568
case "unknown_project":
24552569
envelope["hint"] = "Use one of the available_projects values, or omit project to auto-detect."
24562570
case "invalid_project_config":

0 commit comments

Comments
 (0)