Skip to content

Commit b094052

Browse files
fix(mcp): trust process project override (#378)
1 parent 436c03f commit b094052

4 files changed

Lines changed: 257 additions & 77 deletions

File tree

cmd/engram/main.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -830,14 +830,28 @@ func tryStartAutosync(ctx context.Context, s *store.Store, cfg store.Config) (au
830830
}
831831

832832
func cmdMCP(cfg store.Config) {
833-
// Parse --tools flag. Project is always auto-detected from cwd at call time (JR2-4).
834833
toolsFilter := ""
834+
projectOverride := strings.TrimSpace(os.Getenv("ENGRAM_PROJECT"))
835835
for i := 2; i < len(os.Args); i++ {
836836
if strings.HasPrefix(os.Args[i], "--tools=") {
837837
toolsFilter = strings.TrimPrefix(os.Args[i], "--tools=")
838838
} else if os.Args[i] == "--tools" && i+1 < len(os.Args) {
839839
toolsFilter = os.Args[i+1]
840840
i++
841+
} else if strings.HasPrefix(os.Args[i], "--project=") {
842+
projectOverride = strings.TrimSpace(strings.TrimPrefix(os.Args[i], "--project="))
843+
if projectOverride == "" {
844+
fatal(fmt.Errorf("--project requires a value"))
845+
}
846+
} else if os.Args[i] == "--project" {
847+
if i+1 >= len(os.Args) {
848+
fatal(fmt.Errorf("--project requires a value"))
849+
}
850+
projectOverride = strings.TrimSpace(os.Args[i+1])
851+
if projectOverride == "" {
852+
fatal(fmt.Errorf("--project requires a value"))
853+
}
854+
i++
841855
}
842856
}
843857

@@ -865,7 +879,7 @@ func cmdMCP(cfg store.Config) {
865879
}
866880
defer stopAutosync()
867881

868-
mcpCfg := mcp.MCPConfig{}
882+
mcpCfg := mcp.MCPConfig{DefaultProject: projectOverride}
869883
allowlist := resolveMCPTools(toolsFilter)
870884
mcpSrv := newMCPServerWithConfig(s, mcpCfg, allowlist)
871885

cmd/engram/main_test.go

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -984,9 +984,6 @@ func TestCmdProjectsAllNoGroups(t *testing.T) {
984984
}
985985

986986
func TestCmdMCPDetectsProjectFromFlag(t *testing.T) {
987-
// JR2-4: --project flag is no longer used (dead code removed). The --project flag
988-
// is now silently ignored; project is auto-detected from cwd at each MCP call.
989-
// This test verifies cmdMCP still starts correctly when an unknown flag is passed.
990987
cfg := testConfig(t)
991988

992989
var capturedCfg mcp.MCPConfig
@@ -1008,9 +1005,9 @@ func TestCmdMCPDetectsProjectFromFlag(t *testing.T) {
10081005
withArgs(t, "engram", "mcp", "--project=myproject")
10091006
_, _ = captureOutput(t, func() { cmdMCP(cfg) })
10101007

1011-
// JW6: MCPConfig.DefaultProject removed — verify cmdMCP still calls newMCPServerWithConfig.
1012-
// The project flag is parsed but project is now auto-detected per call, not stored in config.
1013-
_ = capturedCfg // MCPConfig{} — no fields to assert
1008+
if capturedCfg.DefaultProject != "myproject" {
1009+
t.Fatalf("DefaultProject = %q; want myproject", capturedCfg.DefaultProject)
1010+
}
10141011
}
10151012

10161013
func TestCmdMCPDetectsProjectFromEnv(t *testing.T) {
@@ -1035,8 +1032,9 @@ func TestCmdMCPDetectsProjectFromEnv(t *testing.T) {
10351032
withArgs(t, "engram", "mcp")
10361033
_, _ = captureOutput(t, func() { cmdMCP(cfg) })
10371034

1038-
// JW6: MCPConfig.DefaultProject removed — just verify cmdMCP completes without panic.
1039-
_ = capturedCfg
1035+
if capturedCfg.DefaultProject != "env-project" {
1036+
t.Fatalf("DefaultProject = %q; want env-project", capturedCfg.DefaultProject)
1037+
}
10401038
}
10411039

10421040
func TestCmdMCPDetectsProjectFromGit(t *testing.T) {
@@ -1064,8 +1062,9 @@ func TestCmdMCPDetectsProjectFromGit(t *testing.T) {
10641062
withArgs(t, "engram", "mcp")
10651063
_, _ = captureOutput(t, func() { cmdMCP(cfg) })
10661064

1067-
// JW6: MCPConfig.DefaultProject removed — just verify cmdMCP completes without panic.
1068-
_ = capturedCfg
1065+
if capturedCfg.DefaultProject != "" {
1066+
t.Fatalf("DefaultProject = %q; want empty without flag/env", capturedCfg.DefaultProject)
1067+
}
10691068
}
10701069

10711070
func TestCmdSyncUsesDetectProject(t *testing.T) {

internal/mcp/mcp.go

Lines changed: 80 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,16 @@ import (
3030
"github.com/mark3labs/mcp-go/server"
3131
)
3232

33+
const sourceProcessOverride = "process_override"
34+
3335
// MCPConfig holds configuration for the MCP server.
34-
// JW6: DefaultProject removed — it was populated but never read (dead code).
35-
// Project is always auto-detected from cwd at call time via resolveWriteProject/resolveReadProject.
3636
type MCPConfig struct {
37+
// DefaultProject is a trusted process-level project override supplied by
38+
// long-lived MCP hosts (for example, `engram mcp --project NAME` or
39+
// ENGRAM_PROJECT). When set, it is used before cwd detection for MCP
40+
// auto-resolution; per-call project arguments remain separately validated.
41+
DefaultProject string
42+
3743
// BM25Floor overrides the default BM25 score floor used by FindCandidates
3844
// during conflict candidate detection (REQ-001). The floor is the minimum
3945
// acceptable BM25 rank (negative; closer to 0 = better match). Candidates
@@ -503,7 +509,7 @@ Examples:
503509
mcp.Description("Project to echo in envelope context (omit for auto-detect; stats themselves are global aggregates)"),
504510
),
505511
),
506-
handleStats(s),
512+
handleStats(s, cfg),
507513
)
508514
}
509515

@@ -532,7 +538,7 @@ Examples:
532538
mcp.Description("Filter by project name (omit for auto-detect)"),
533539
),
534540
),
535-
handleTimeline(s),
541+
handleTimeline(s, cfg),
536542
)
537543
}
538544

@@ -551,7 +557,7 @@ Examples:
551557
mcp.Description("The observation ID to retrieve"),
552558
),
553559
),
554-
handleGetObservation(s),
560+
handleGetObservation(s, cfg),
555561
)
556562
}
557563

@@ -721,7 +727,7 @@ Duplicates are automatically detected and skipped — safe to call multiple time
721727
mcp.WithIdempotentHintAnnotation(true),
722728
mcp.WithOpenWorldHintAnnotation(false),
723729
),
724-
handleCurrentProject(s),
730+
handleCurrentProject(s, cfg),
725731
)
726732
}
727733

@@ -739,7 +745,7 @@ Duplicates are automatically detected and skipped — safe to call multiple time
739745
mcp.WithString("project", mcp.Description("Project to diagnose (omit for auto-detect)")),
740746
mcp.WithString("check", mcp.Description("Optional diagnostic check code to run")),
741747
),
742-
handleDoctor(s),
748+
handleDoctor(s, cfg),
743749
)
744750
}
745751

@@ -860,10 +866,13 @@ ERROR: Returns IsError=true if IDs are unknown, relation is invalid, or cross-pr
860866
// handleCurrentProject implements mem_current_project. It NEVER returns an error
861867
// even on ambiguous cwd — it always returns a success result with whatever
862868
// detection info is available (REQ-313).
863-
func handleCurrentProject(s *store.Store) server.ToolHandlerFunc {
869+
func handleCurrentProject(s *store.Store, cfg MCPConfig) server.ToolHandlerFunc {
864870
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
865871
cwd, _ := os.Getwd()
866872
res := projectpkg.DetectProjectFull(cwd)
873+
if processRes, ok := processProjectResult(cfg.DefaultProject); ok {
874+
res = processRes
875+
}
867876

868877
envelope := map[string]any{
869878
"project": res.Project,
@@ -893,7 +902,7 @@ func handleSearch(s *store.Store, cfg MCPConfig, activity *SessionActivity) serv
893902
limit := intArg(req, "limit", 10)
894903

895904
// Resolve project: validate override or auto-detect (REQ-310, REQ-311)
896-
detRes, err := resolveReadProject(s, projectOverride)
905+
detRes, err := resolveReadProjectWithProcessOverride(s, projectOverride, cfg.DefaultProject)
897906
if err != nil {
898907
var upe *unknownProjectError
899908
if errors.As(err, &upe) {
@@ -1052,8 +1061,8 @@ func handleSave(s *store.Store, cfg MCPConfig, activity *SessionActivity) server
10521061
}
10531062

10541063
// Resolve write project using the full MCP precedence: explicit request,
1055-
// existing session association, repo config/directory detection, then cwd fallback.
1056-
detRes, err := resolveSaveWriteProject(s, projectChoice, explicitProjectProvided, projectChoiceReason, sessionID, validateRecoveryToken)
1064+
// existing session association, process override, repo config/directory detection, then cwd fallback.
1065+
detRes, err := resolveSaveWriteProjectWithProcessOverride(s, projectChoice, explicitProjectProvided, projectChoiceReason, sessionID, validateRecoveryToken, cfg.DefaultProject)
10571066
if err != nil {
10581067
return writeProjectErrorResult(activity, recoverySessionID, detRes, err), nil
10591068
}
@@ -1310,7 +1319,7 @@ func handleSavePrompt(s *store.Store, cfg MCPConfig, activity *SessionActivity)
13101319
return true, activity.ValidateAmbiguousProjectRecoveryToken(recoverySessionID, recoveryToken, strings.TrimSpace(choice), res.AvailableProjects, res.Path)
13111320
}
13121321

1313-
detRes, err := resolveWriteProjectWithChoice(projectChoice, projectChoiceReason, validateRecoveryToken)
1322+
detRes, err := resolveWriteProjectWithChoiceAndProcessOverride(projectChoice, projectChoiceReason, validateRecoveryToken, cfg.DefaultProject)
13141323
if err != nil {
13151324
return writeProjectErrorResult(activity, recoverySessionID, detRes, err), nil
13161325
}
@@ -1347,7 +1356,7 @@ func handleContext(s *store.Store, cfg MCPConfig, activity *SessionActivity) ser
13471356
scope, _ := req.GetArguments()["scope"].(string)
13481357

13491358
// Resolve project: validate override or auto-detect (REQ-310, REQ-311)
1350-
detRes, err := resolveReadProject(s, projectOverride)
1359+
detRes, err := resolveReadProjectWithProcessOverride(s, projectOverride, cfg.DefaultProject)
13511360
if err != nil {
13521361
var upe *unknownProjectError
13531362
if errors.As(err, &upe) {
@@ -1393,12 +1402,12 @@ func handleContext(s *store.Store, cfg MCPConfig, activity *SessionActivity) ser
13931402
}
13941403
}
13951404

1396-
func handleStats(s *store.Store) server.ToolHandlerFunc {
1405+
func handleStats(s *store.Store, cfg MCPConfig) server.ToolHandlerFunc {
13971406
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
13981407
projectOverride, _ := req.GetArguments()["project"].(string)
13991408

14001409
// Resolve project: validate override or auto-detect (REQ-310, REQ-311, REQ-314)
1401-
detRes, err := resolveReadProject(s, projectOverride)
1410+
detRes, err := resolveReadProjectWithProcessOverride(s, projectOverride, cfg.DefaultProject)
14021411
if err != nil {
14031412
var upe *unknownProjectError
14041413
if errors.As(err, &upe) {
@@ -1430,14 +1439,14 @@ func handleStats(s *store.Store) server.ToolHandlerFunc {
14301439
}
14311440

14321441
func DoctorToolHandler(s *store.Store) server.ToolHandlerFunc {
1433-
return handleDoctor(s)
1442+
return handleDoctor(s, MCPConfig{})
14341443
}
14351444

1436-
func handleDoctor(s *store.Store) server.ToolHandlerFunc {
1445+
func handleDoctor(s *store.Store, cfg MCPConfig) server.ToolHandlerFunc {
14371446
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
14381447
projectOverride, _ := req.GetArguments()["project"].(string)
14391448
check, _ := req.GetArguments()["check"].(string)
1440-
detRes, err := resolveReadProject(s, projectOverride)
1449+
detRes, err := resolveReadProjectWithProcessOverride(s, projectOverride, cfg.DefaultProject)
14411450
if err != nil {
14421451
var upe *unknownProjectError
14431452
if errors.As(err, &upe) {
@@ -1470,7 +1479,7 @@ func handleDoctor(s *store.Store) server.ToolHandlerFunc {
14701479
}
14711480
}
14721481

1473-
func handleTimeline(s *store.Store) server.ToolHandlerFunc {
1482+
func handleTimeline(s *store.Store, cfg MCPConfig) server.ToolHandlerFunc {
14741483
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
14751484
observationID := int64(intArg(req, "observation_id", 0))
14761485
if observationID == 0 {
@@ -1481,7 +1490,7 @@ func handleTimeline(s *store.Store) server.ToolHandlerFunc {
14811490
projectOverride, _ := req.GetArguments()["project"].(string)
14821491

14831492
// Resolve project: validate override or auto-detect (REQ-310, REQ-311, REQ-314)
1484-
detRes, err := resolveReadProject(s, projectOverride)
1493+
detRes, err := resolveReadProjectWithProcessOverride(s, projectOverride, cfg.DefaultProject)
14851494
if err != nil {
14861495
var upe *unknownProjectError
14871496
if errors.As(err, &upe) {
@@ -1536,7 +1545,7 @@ func handleTimeline(s *store.Store) server.ToolHandlerFunc {
15361545
}
15371546
}
15381547

1539-
func handleGetObservation(s *store.Store) server.ToolHandlerFunc {
1548+
func handleGetObservation(s *store.Store, cfg MCPConfig) server.ToolHandlerFunc {
15401549
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
15411550
id := int64(intArg(req, "id", 0))
15421551
if id == 0 {
@@ -1548,10 +1557,10 @@ func handleGetObservation(s *store.Store) server.ToolHandlerFunc {
15481557
return mcp.NewToolResultError(fmt.Sprintf("Observation #%d not found", id)), nil
15491558
}
15501559

1551-
// Resolve project from cwd (REQ-310, REQ-314). No override possible for
1552-
// get-by-ID — always auto-detect. JW5: use resolveReadProject (read semantics).
1553-
// Tolerant: don't fail the fetch on resolution error; degrade to plain text.
1554-
detRes, detErr := resolveReadProject(s, "")
1560+
// Resolve project from process override/cwd (REQ-310, REQ-314). No per-call
1561+
// override possible for get-by-ID. Tolerant: don't fail the fetch on
1562+
// resolution error; degrade to plain text.
1563+
detRes, detErr := resolveReadProjectWithProcessOverride(s, "", cfg.DefaultProject)
15551564

15561565
obsProject := ""
15571566
if obs.Project != nil {
@@ -2024,8 +2033,36 @@ func resolveWriteProject() (projectpkg.DetectionResult, error) {
20242033
return res, nil
20252034
}
20262035

2036+
func processProjectResult(project string) (projectpkg.DetectionResult, bool) {
2037+
project = strings.TrimSpace(project)
2038+
if project == "" {
2039+
return projectpkg.DetectionResult{}, false
2040+
}
2041+
normalized, warning := store.NormalizeProject(project)
2042+
return projectpkg.DetectionResult{
2043+
Project: normalized,
2044+
Source: sourceProcessOverride,
2045+
Path: "",
2046+
Warning: warning,
2047+
}, true
2048+
}
2049+
2050+
func resolveWriteProjectWithProcessOverride(defaultProject string) (projectpkg.DetectionResult, error) {
2051+
if res, ok := processProjectResult(defaultProject); ok {
2052+
return res, nil
2053+
}
2054+
return resolveWriteProject()
2055+
}
2056+
20272057
type ambiguousRecoveryTokenValidator func(projectpkg.DetectionResult, string) (provided bool, valid bool)
20282058

2059+
func resolveWriteProjectWithChoiceAndProcessOverride(projectChoice, reason string, validateToken ambiguousRecoveryTokenValidator, defaultProject string) (projectpkg.DetectionResult, error) {
2060+
if strings.TrimSpace(projectChoice) == "" {
2061+
return resolveWriteProjectWithProcessOverride(defaultProject)
2062+
}
2063+
return resolveWriteProjectWithChoice(projectChoice, reason, validateToken)
2064+
}
2065+
20292066
// resolveWriteProjectWithChoice preserves normal write resolution authority and
20302067
// only uses an explicit project choice as a recovery path from ErrAmbiguousProject.
20312068
func resolveWriteProjectWithChoice(projectChoice, reason string, validateToken ambiguousRecoveryTokenValidator) (projectpkg.DetectionResult, error) {
@@ -2081,6 +2118,15 @@ func resolveWriteProjectWithChoice(projectChoice, reason string, validateToken a
20812118
return res, nil
20822119
}
20832120

2121+
func resolveSaveWriteProjectWithProcessOverride(s *store.Store, projectChoice string, explicitProjectProvided bool, reason, sessionID string, validateToken ambiguousRecoveryTokenValidator, defaultProject string) (projectpkg.DetectionResult, error) {
2122+
if !explicitProjectProvided && strings.TrimSpace(projectChoice) == "" && strings.TrimSpace(sessionID) == "" && strings.TrimSpace(reason) == "" {
2123+
if processRes, ok := processProjectResult(defaultProject); ok {
2124+
return processRes, nil
2125+
}
2126+
}
2127+
return resolveSaveWriteProject(s, projectChoice, explicitProjectProvided, reason, sessionID, validateToken)
2128+
}
2129+
20842130
func resolveSaveWriteProject(s *store.Store, projectChoice string, explicitProjectProvided bool, reason, sessionID string, validateToken ambiguousRecoveryTokenValidator) (projectpkg.DetectionResult, error) {
20852131
trimmedSessionID := strings.TrimSpace(sessionID)
20862132
trimmedProjectChoice := strings.TrimSpace(projectChoice)
@@ -2408,6 +2454,15 @@ func resolveAmbiguousChoicePath(ambiguousParent, choice string) string {
24082454
// If override is empty, falls back to auto-detection from cwd.
24092455
// JW2: normalizes the override (lowercase+trim) before ProjectExists lookup so
24102456
// that e.g. "MyApp" and " myapp " both resolve to the stored "myapp".
2457+
func resolveReadProjectWithProcessOverride(s *store.Store, override, defaultProject string) (projectpkg.DetectionResult, error) {
2458+
if strings.TrimSpace(override) == "" {
2459+
if res, ok := processProjectResult(defaultProject); ok {
2460+
return res, nil
2461+
}
2462+
}
2463+
return resolveReadProject(s, override)
2464+
}
2465+
24112466
func resolveReadProject(s *store.Store, override string) (projectpkg.DetectionResult, error) {
24122467
override = strings.TrimSpace(override)
24132468
if override == "" {

0 commit comments

Comments
 (0)