Skip to content

Commit 7658246

Browse files
committed
fix(mcp): panic-guard tool registration via mustAddTool wrapper
1 parent a49d4e7 commit 7658246

2 files changed

Lines changed: 113 additions & 43 deletions

File tree

internal/mcp/tools.go

Lines changed: 68 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -48,63 +48,88 @@ type emptyArgs struct{}
4848

4949
// Register binds all 8 MCP tools onto the provided SDK server.
5050
//
51-
// Tool names are bit-identical to Python `mcp/server/server.py`. Order in
52-
// this Register call is for readability only; the SDK sorts tools
53-
// alphabetically in `tools/list` output.
51+
// Tool names are bit-identical to Python `mcp/server/server.py`. SDK sorts
52+
// tools alphabetically in `tools/list` output, so order here is for readability.
5453
//
55-
// AddTool can panic on schema-inference failure (SDK behavior). Register
56-
// recovers so a misconfigured tool surfaces as a startup error instead of
57-
// taking the process down silently after binding.
54+
// Failure modes that Register surfaces as a startup error (via panic +
55+
// recover):
56+
// 1. mustAddTool name validation (SDK's validateToolName has a log-only
57+
// branch — server.go:238-241 — that we bypass by panicking up-front).
58+
// 2. SDK schema-inference panic (toolForErr).
59+
// 3. SDK schema-shape panic (Server.AddTool).
60+
//
61+
// Result: every registration either succeeds completely or returns an error.
62+
// No silent half-registrations.
5863
func Register(srv *sdkmcp.Server, deps *Deps) (err error) {
5964
defer func() {
6065
if r := recover(); r != nil {
61-
err = fmt.Errorf("mcp.Register: AddTool panic: %v", r)
66+
err = fmt.Errorf("mcp.Register: %v", r)
6267
}
6368
}()
6469

6570
// Write tools (state gate applies in Phase 5).
66-
sdkmcp.AddTool(srv, &sdkmcp.Tool{
67-
Name: "rune_capture",
68-
Description: "Capture a decision record (agent-delegated extraction required).",
69-
}, stubHandler[domain.CaptureRequest, domain.CaptureResponse](deps, "rune_capture"))
70-
71-
sdkmcp.AddTool(srv, &sdkmcp.Tool{
72-
Name: "rune_batch_capture",
73-
Description: "Capture a batch of decision records (e.g. session-end sweep).",
74-
}, stubHandler[service.BatchCaptureArgs, service.BatchCaptureResult](deps, "rune_batch_capture"))
75-
76-
sdkmcp.AddTool(srv, &sdkmcp.Tool{
77-
Name: "rune_recall",
78-
Description: "Query organizational memory by natural-language question.",
79-
}, stubHandler[domain.RecallArgs, domain.RecallResult](deps, "rune_recall"))
80-
81-
sdkmcp.AddTool(srv, &sdkmcp.Tool{
82-
Name: "rune_delete_capture",
83-
Description: "Soft-delete a record by ID (sets status=reverted, re-inserts).",
84-
}, stubHandler[service.DeleteCaptureArgs, service.DeleteCaptureResult](deps, "rune_delete_capture"))
71+
mustAddTool[domain.CaptureRequest, domain.CaptureResponse](srv, deps,
72+
"rune_capture",
73+
"Capture a decision record (agent-delegated extraction required).")
74+
mustAddTool[service.BatchCaptureArgs, service.BatchCaptureResult](srv, deps,
75+
"rune_batch_capture",
76+
"Capture a batch of decision records (e.g. session-end sweep).")
77+
mustAddTool[domain.RecallArgs, domain.RecallResult](srv, deps,
78+
"rune_recall",
79+
"Query organizational memory by natural-language question.")
80+
mustAddTool[service.DeleteCaptureArgs, service.DeleteCaptureResult](srv, deps,
81+
"rune_delete_capture",
82+
"Soft-delete a record by ID (sets status=reverted, re-inserts).")
8583

8684
// Read / diagnostic tools (state gate bypass).
87-
sdkmcp.AddTool(srv, &sdkmcp.Tool{
88-
Name: "rune_capture_history",
89-
Description: "List recent captures from local capture_log.jsonl (read-only).",
90-
}, stubHandler[service.CaptureHistoryArgs, service.CaptureHistoryResult](deps, "rune_capture_history"))
91-
92-
sdkmcp.AddTool(srv, &sdkmcp.Tool{
93-
Name: "rune_vault_status",
94-
Description: "Probe Vault connectivity and report secure-search mode.",
95-
}, stubHandler[emptyArgs, service.VaultStatusResult](deps, "rune_vault_status"))
85+
mustAddTool[service.CaptureHistoryArgs, service.CaptureHistoryResult](srv, deps,
86+
"rune_capture_history",
87+
"List recent captures from local capture_log.jsonl (read-only).")
88+
mustAddTool[emptyArgs, service.VaultStatusResult](srv, deps,
89+
"rune_vault_status",
90+
"Probe Vault connectivity and report secure-search mode.")
91+
mustAddTool[emptyArgs, service.DiagnosticsResult](srv, deps,
92+
"rune_diagnostics",
93+
"Collect a 7-section health snapshot (env / state / vault / keys / pipelines / embedding / envector).")
94+
mustAddTool[emptyArgs, service.ReloadPipelinesResult](srv, deps,
95+
"rune_reload_pipelines",
96+
"Re-initialize Vault + envector pipelines (BOOT replay) with envector warmup.")
9697

97-
sdkmcp.AddTool(srv, &sdkmcp.Tool{
98-
Name: "rune_diagnostics",
99-
Description: "Collect a 7-section health snapshot (env / state / vault / keys / pipelines / embedding / envector).",
100-
}, stubHandler[emptyArgs, service.DiagnosticsResult](deps, "rune_diagnostics"))
98+
return nil
99+
}
101100

101+
// mustAddTool wraps sdkmcp.AddTool with up-front name validation.
102+
//
103+
// The SDK's Server.AddTool only LOGS on invalid tool names
104+
// (go-sdk/mcp/server.go:238-241) — it does not panic, so Register's
105+
// defer recover() would miss it and the bad-named tool would silently
106+
// register. mustAddTool panics on invalid names, unifying the failure
107+
// path so recover() catches everything.
108+
func mustAddTool[In, Out any](srv *sdkmcp.Server, deps *Deps, name, description string) {
109+
if !isValidToolName(name) {
110+
panic(fmt.Errorf("mustAddTool: invalid tool name %q (allowed: [A-Za-z0-9_-], 1..128 chars)", name))
111+
}
102112
sdkmcp.AddTool(srv, &sdkmcp.Tool{
103-
Name: "rune_reload_pipelines",
104-
Description: "Re-initialize Vault + envector pipelines (BOOT replay) with envector warmup.",
105-
}, stubHandler[emptyArgs, service.ReloadPipelinesResult](deps, "rune_reload_pipelines"))
113+
Name: name,
114+
Description: description,
115+
}, stubHandler[In, Out](deps, name))
116+
}
106117

107-
return nil
118+
// isValidToolName mirrors the SDK's validateToolName rules
119+
// (go-sdk/mcp/tool.go:109): non-empty, ≤128 chars, only [A-Za-z0-9_-].
120+
// Update this when bumping the SDK if its validation tightens.
121+
func isValidToolName(name string) bool {
122+
if name == "" || len(name) > 128 {
123+
return false
124+
}
125+
for _, r := range name {
126+
ok := (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') ||
127+
(r >= '0' && r <= '9') || r == '_' || r == '-'
128+
if !ok {
129+
return false
130+
}
131+
}
132+
return true
108133
}
109134

110135
// stubHandler returns a SDK ToolHandlerFor that always responds with a

internal/mcp/tools_test.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// Internal-package tests for the panic-guard wrapper around sdkmcp.AddTool.
2+
// (Phase A.5 in-memory smoke is in register_test.go which uses package mcp_test.)
3+
4+
package mcp
5+
6+
import (
7+
"strings"
8+
"testing"
9+
10+
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
11+
)
12+
13+
func TestIsValidToolName(t *testing.T) {
14+
if !isValidToolName("rune_capture") {
15+
t.Error("rune_capture should be valid")
16+
}
17+
if isValidToolName("") {
18+
t.Error("empty should be invalid")
19+
}
20+
if isValidToolName("rune capture") {
21+
t.Error("name with space should be invalid")
22+
}
23+
if isValidToolName(strings.Repeat("a", 129)) {
24+
t.Error("name >128 chars should be invalid")
25+
}
26+
}
27+
28+
func TestMustAddTool_PanicsOnInvalidName(t *testing.T) {
29+
defer func() {
30+
if r := recover(); r == nil {
31+
t.Error("mustAddTool with invalid name did not panic")
32+
}
33+
}()
34+
srv := sdkmcp.NewServer(&sdkmcp.Implementation{Name: "x", Version: "0"}, nil)
35+
mustAddTool[emptyArgs, emptyArgs](srv, &Deps{}, "rune capture", "test")
36+
}
37+
38+
func TestRegister_AllHardcodedNamesValid(t *testing.T) {
39+
// Sanity: Register's 8 hardcoded names all pass mustAddTool's check.
40+
// Catches an accidental typo in tools.go before Phase A.5 integration test runs.
41+
srv := sdkmcp.NewServer(&sdkmcp.Implementation{Name: "x", Version: "0"}, nil)
42+
if err := Register(srv, &Deps{}); err != nil {
43+
t.Errorf("Register returned error: %v", err)
44+
}
45+
}

0 commit comments

Comments
 (0)