Skip to content

Commit db1e561

Browse files
authored
fix: security and correctness hardening across API, tools, MCP, persistence, and gateways (#59)
* fix: harden security and correctness across API, tools, and MCP Critical: - api/server.go: add ReadHeaderTimeout/Read/Write/Idle timeouts (slowloris DoS fix); add validateAuthConfig loopback auth gate; Stop uses 15s bounded timeout instead of unbounded ctx - tool/file_write.go: replace os.WriteFile with atomic temp-file -> sync -> rename (crash-safe writes); log backup failures via slog.Warn instead of silently dropping them High: - tool/bash.go: replace cmd.CombinedOutput with limitedWriter that caps memory at 500 KB (yes / cat /dev/urandom can no longer OOM the process before the timeout kills the command) - tool/safety.go: validateURLPublic now returns originalHost so callers can preserve the Host header for virtual-host routing; add 0.0.0.0/8, 100.64.0.0/10, 198.18.0.0/15 SSRF ranges; check IPv4-mapped IPv6 (::ffff:x.x.x.x) against IPv4 CIDR blocks - tool/web_fetch.go + download.go: set req.Host to preserve original hostname after IP pinning - mcp/mcp.go: time.After -> time.NewTimer + Stop (avoid timer leak); preserve JSON-RPC error code/message in returned error via pendErrors map; log scanner.Err() with slog.Warn on oversized responses Medium: - permissions/guardian.go: isBase64Injection uses byte iteration (consistent with len(s) byte count) instead of mixing rune count with byte length * fix: eliminate timer leaks in retry, add security headers, bump eyrie submodule - tool/retry.go: replace time.After with time.NewTimer+Stop in RetryExecutor backoff wait - resilience/retry/retry.go: same fix in Do() and DoWithResult() backoff waits - api/server.go: add securityHeaders middleware (X-Content-Type- Options, X-Frame-Options, Cache-Control) on all routes - external/eyrie: bump submodule to c5ab1f0 — picks up eyrie's timer-leak fixes (ratelimit, adaptive_ratelimit), security headers, and all prior security/correctness fixes (loopback auth guard, stream close, guardrails safe variants, etc.) * fix: eliminate timer leak in engine stream retry backoff - engine/stream.go: replace time.After with time.NewTimer+Stop in the stream retry backoff wait (line 464) to avoid leaking the timer in the runtime when ctx is cancelled before the delay elapses * chore(submodules): bump eyrie to 4851357 (router timer leak fix) * fix: eliminate ratelimit timer leak, bound snapshot goroutine, bump eyrie - resilience/ratelimit/ratelimit.go: replace time.After with time.NewTimer+Stop in Wait() backoff loop - engine/stream.go: add 30s timeout context to snapshot Track goroutine (was fire-and-forget with no timeout) - external/eyrie: bump submodule to 356184a — picks up centralized httputil package, bounded keyring lookups, and all prior fixes * fix: harden persistence, sandbox proxy, and daemon gateways Session persistence: - persist.go: atomic write with sync before rename (was os.WriteFile without sync — crash could leave partial file) - sqlite_store.go: add SetMaxOpenConns(1) + busy_timeout=5000 pragma to prevent 'database is locked' under concurrent access Sandbox: - netproxy.go: add HTTP server timeouts (ReadHeader 10s, Read 30s, Write 5min for CONNECT tunnels, Idle 120s) — was no timeouts, vulnerable to slowloris Daemon gateways: - telegram.go: replace fmt.Sprintf JSON injection with json.Marshal; bound response body read to 1 MiB; rune-safe truncation at 4000 - discord.go: bound error response body read to 4 KiB - gateway.go: bound response body read to 1 MiB * style: gofumpt formatting for mcp.go * fix: address review findings in security hardening PR - bash.go: cap limitedWriter at maxOutputBytes+1 so TruncateOutput's truncation marker fires when output exceeds the cap (was silently lost) - stream.go: thread snapCtx into TrackCtx so the 30s snapshot timeout actually bounds git operations (was dead code; Track ignored context) - snapshot.go: add TrackCtx/gitWorkCtx/gitWorkOutputCtx to accept context - safety.go: remove dead IPv4-mapped IPv6 checkIPs branch (net.IPNet.Contains already handles mapped addresses via To4) - mcp.go: only delete pendErrors for undelivered requests in readLoop cleanup, preserving error details for already-signaled callers * chore(submodules): bump eyrie to 138b60d (review fixes)
1 parent b5c4b16 commit db1e561

22 files changed

Lines changed: 329 additions & 85 deletions

File tree

internal/api/server.go

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@ import (
55
"context"
66
"crypto/subtle"
77
"encoding/json"
8+
"fmt"
89
"io"
910
"net"
1011
"net/http"
1112
"strings"
1213
"sync"
14+
"time"
1315
)
1416

1517
const maxRequestBodyBytes = 1 << 20
@@ -74,9 +76,19 @@ func NewWithAPIKey(addr, apiKey string) *Server {
7476

7577
// registerRoutes sets up the HTTP endpoints.
7678
func (s *Server) registerRoutes() {
77-
s.mux.HandleFunc("GET /health", s.handleHealth)
78-
s.mux.HandleFunc("GET /version", s.handleVersion)
79-
s.mux.HandleFunc("POST /chat", s.auth(s.handleChat))
79+
s.mux.HandleFunc("GET /health", securityHeaders(s.handleHealth))
80+
s.mux.HandleFunc("GET /version", securityHeaders(s.handleVersion))
81+
s.mux.HandleFunc("POST /chat", securityHeaders(s.auth(s.handleChat)))
82+
}
83+
84+
// securityHeaders sets standard HTTP security headers on every response.
85+
func securityHeaders(next http.HandlerFunc) http.HandlerFunc {
86+
return func(w http.ResponseWriter, r *http.Request) {
87+
w.Header().Set("X-Content-Type-Options", "nosniff")
88+
w.Header().Set("X-Frame-Options", "DENY")
89+
w.Header().Set("Cache-Control", "no-store")
90+
next(w, r)
91+
}
8092
}
8193

8294
func (s *Server) auth(next http.HandlerFunc) http.HandlerFunc {
@@ -110,6 +122,35 @@ func constantTimeEqual(a, b string) bool {
110122
return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1
111123
}
112124

125+
// validateAuthConfig refuses to start the server with no API key on a
126+
// non-loopback bind. The auth middleware silently allows every request when
127+
// the API key is empty, so a misconfigured server would be wide open. The
128+
// only safe no-key mode is loopback bind.
129+
func (s *Server) validateAuthConfig() error {
130+
if s.apiKey != "" {
131+
return nil
132+
}
133+
host, _, err := net.SplitHostPort(s.addr)
134+
if err != nil {
135+
return fmt.Errorf("api: invalid bind address %q: %w", s.addr, err)
136+
}
137+
if !isLoopbackHost(host) {
138+
return fmt.Errorf("api: apiKey is empty and bind address %q is not loopback; refusing to start. Set apiKey or bind to 127.0.0.1", s.addr)
139+
}
140+
return nil
141+
}
142+
143+
// isLoopbackHost reports whether host is a loopback address.
144+
func isLoopbackHost(host string) bool {
145+
if host == "" || host == "localhost" {
146+
return host == "localhost" // "" is unsafe; "localhost" is loopback
147+
}
148+
if ip := net.ParseIP(host); ip != nil {
149+
return ip.IsLoopback()
150+
}
151+
return false
152+
}
153+
113154
func decodeJSONBody(w http.ResponseWriter, r *http.Request, dst any) bool {
114155
r.Body = http.MaxBytesReader(w, r.Body, maxRequestBodyBytes)
115156
dec := json.NewDecoder(r.Body)
@@ -127,10 +168,17 @@ func decodeJSONBody(w http.ResponseWriter, r *http.Request, dst any) bool {
127168

128169
// Start starts the HTTP server. It blocks until the context is cancelled or an error occurs.
129170
func (s *Server) Start(ctx context.Context) error {
171+
if err := s.validateAuthConfig(); err != nil {
172+
return err
173+
}
130174
s.mu.Lock()
131175
s.server = &http.Server{
132-
Addr: s.addr,
133-
Handler: s.mux,
176+
Addr: s.addr,
177+
Handler: s.mux,
178+
ReadHeaderTimeout: 10 * time.Second,
179+
ReadTimeout: 30 * time.Second,
180+
WriteTimeout: 30 * time.Second,
181+
IdleTimeout: 120 * time.Second,
134182
}
135183
s.mu.Unlock()
136184

@@ -151,7 +199,7 @@ func (s *Server) Start(ctx context.Context) error {
151199
return err
152200
}
153201

154-
// Stop gracefully shuts down the HTTP server.
202+
// Stop gracefully shuts down the HTTP server with a 15-second timeout.
155203
func (s *Server) Stop(ctx context.Context) error {
156204
s.mu.Lock()
157205
srv := s.server
@@ -160,7 +208,12 @@ func (s *Server) Stop(ctx context.Context) error {
160208
if srv == nil {
161209
return nil
162210
}
163-
return srv.Shutdown(ctx)
211+
// Use a bounded timeout so Stop cannot hang indefinitely if a
212+
// client keeps a connection open. The caller's ctx is respected
213+
// if it has a shorter deadline.
214+
shutdownCtx, cancel := context.WithTimeout(ctx, 15*time.Second)
215+
defer cancel()
216+
return srv.Shutdown(shutdownCtx)
164217
}
165218

166219
// Handler returns the underlying http.Handler for testing purposes.

internal/daemon/discord.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ func (g *DiscordGateway) fetchMessagesREST(ctx context.Context, channelID, after
201201
}
202202
defer func() { _ = resp.Body.Close() }()
203203
if resp.StatusCode != http.StatusOK {
204-
data, _ := io.ReadAll(resp.Body)
204+
data, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
205205
return nil, fmt.Errorf("discord messages: HTTP %d: %s", resp.StatusCode, string(data))
206206
}
207207
var msgs []discordMessage

internal/daemon/gateway.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ func forwardToHawk(ctx context.Context, client *http.Client, daemonAddr, apiKey,
2929
return "", err
3030
}
3131
defer func() { _ = resp.Body.Close() }()
32-
body, _ := io.ReadAll(resp.Body)
32+
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
3333
var chatResp struct {
3434
Response string `json:"response"`
3535
}

internal/daemon/telegram.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -152,17 +152,23 @@ func (tg *TelegramGateway) handleMessage(ctx context.Context, msg *TelegramMessa
152152
response = fmt.Sprintf("Error: %v", err)
153153
}
154154

155-
// Format for Telegram (truncate if too long)
156-
if len(response) > 4000 {
157-
response = response[:4000] + "\n\n... (truncated)"
155+
// Format for Telegram (truncate if too long, at rune boundary)
156+
if len([]rune(response)) > 4000 {
157+
response = string([]rune(response)[:4000]) + "\n\n... (truncated)"
158158
}
159159

160160
_ = tg.sendMessage(ctx, msg.Chat.ID, response)
161161
}
162162

163163
func (tg *TelegramGateway) forwardToHawk(ctx context.Context, prompt string) (string, error) {
164-
payload := fmt.Sprintf(`{"prompt":%q}`, prompt)
165-
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tg.DaemonAddr+"/v1/chat", strings.NewReader(payload))
164+
// Use json.Marshal for safe JSON encoding instead of fmt.Sprintf
165+
// with %q, which does not handle all JSON edge cases (e.g., control
166+
// characters, surrogate pairs).
167+
payload, err := json.Marshal(map[string]string{"prompt": prompt})
168+
if err != nil {
169+
return "", fmt.Errorf("encode prompt: %w", err)
170+
}
171+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, tg.DaemonAddr+"/v1/chat", strings.NewReader(string(payload)))
166172
if err != nil {
167173
return "", err
168174
}
@@ -177,7 +183,8 @@ func (tg *TelegramGateway) forwardToHawk(ctx context.Context, prompt string) (st
177183
}
178184
defer func() { _ = resp.Body.Close() }()
179185

180-
body, _ := io.ReadAll(resp.Body)
186+
// Limit response body to 1 MiB to prevent memory exhaustion.
187+
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
181188
var chatResp struct {
182189
Response string `json:"response"`
183190
}

internal/engine/session.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ type MemoryRecaller interface {
3535
// SnapshotTracker abstracts the snapshot system so engine doesn't import snapshot directly.
3636
type SnapshotTracker interface {
3737
Track(message string) (string, error)
38+
TrackCtx(ctx context.Context, message string) (string, error)
3839
}
3940

4041
// Session manages a conversation with an LLM via eyrie.

internal/engine/stream.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -460,9 +460,11 @@ func (s *Session) agentLoop(ctx context.Context, ch chan<- StreamEvent) {
460460
"reason": retryReason,
461461
"error": streamErr.Error(),
462462
})
463+
retryTimer := time.NewTimer(time.Duration(streamAttempt+1) * time.Second)
463464
select {
464-
case <-time.After(time.Duration(streamAttempt+1) * time.Second):
465+
case <-retryTimer.C:
465466
case <-ctx.Done():
467+
retryTimer.Stop()
466468
ch <- StreamEvent{Type: "error", Content: "stream retry cancelled: " + ctx.Err().Error()}
467469
result.Close()
468470
return
@@ -730,7 +732,13 @@ func (s *Session) agentLoop(ctx context.Context, ch chan<- StreamEvent) {
730732
}
731733
}
732734
if len(writeNames) > 0 {
733-
go func() { _, _ = s.Snapshots.Track(strings.Join(writeNames, ", ")) }()
735+
go func() {
736+
// Bound the snapshot so a slow filesystem doesn't
737+
// leak a goroutine after the session ends.
738+
snapCtx, snapCancel := context.WithTimeout(context.Background(), 30*time.Second)
739+
defer snapCancel()
740+
_, _ = s.Snapshots.TrackCtx(snapCtx, strings.Join(writeNames, ", "))
741+
}()
734742
}
735743
}
736744

internal/mcp/mcp.go

Lines changed: 55 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"encoding/json"
77
"fmt"
88
"io"
9+
"log/slog"
910
"os/exec"
1011
"strings"
1112
"sync"
@@ -24,17 +25,18 @@ func SetClientVersion(v string) { clientVersion = v }
2425

2526
// Server represents a connected MCP server.
2627
type Server struct {
27-
Name string
28-
Command string
29-
Args []string
30-
cmd *exec.Cmd
31-
stdin io.WriteCloser
32-
stdout io.ReadCloser
33-
mu sync.Mutex
34-
nextID int
35-
reader *bufio.Scanner
36-
pending map[int]chan json.RawMessage // response channels keyed by request ID
37-
pendMu sync.Mutex
28+
Name string
29+
Command string
30+
Args []string
31+
cmd *exec.Cmd
32+
stdin io.WriteCloser
33+
stdout io.ReadCloser
34+
mu sync.Mutex
35+
nextID int
36+
reader *bufio.Scanner
37+
pending map[int]chan json.RawMessage // response channels keyed by request ID
38+
pendErrors map[int]string // error details keyed by request ID
39+
pendMu sync.Mutex
3840
}
3941

4042
// Tool is a tool exposed by an MCP server.
@@ -91,14 +93,15 @@ func Connect(ctx context.Context, name, command string, args ...string) (*Server
9193
scanner.Buffer(make([]byte, 0, 1024*1024), 1024*1024) // 1MB buffer
9294

9395
s := &Server{
94-
Name: name,
95-
Command: command,
96-
Args: args,
97-
cmd: cmd,
98-
stdin: stdin,
99-
stdout: stdout,
100-
reader: scanner,
101-
pending: make(map[int]chan json.RawMessage),
96+
Name: name,
97+
Command: command,
98+
Args: args,
99+
cmd: cmd,
100+
stdin: stdin,
101+
stdout: stdout,
102+
reader: scanner,
103+
pending: make(map[int]chan json.RawMessage),
104+
pendErrors: make(map[int]string),
102105
}
103106

104107
// Start background reader to dispatch responses and notifications
@@ -142,6 +145,11 @@ func (s *Server) readLoop() {
142145
s.pendMu.Unlock()
143146
if ok {
144147
if msg.Error != nil {
148+
// Store error details so the caller can include them
149+
// in the returned error instead of a generic message.
150+
s.pendMu.Lock()
151+
s.pendErrors[msg.ID] = fmt.Sprintf("code %d: %s", msg.Error.Code, msg.Error.Message)
152+
s.pendMu.Unlock()
145153
ch <- nil // signal error via nil
146154
} else {
147155
ch <- msg.Result
@@ -152,11 +160,19 @@ func (s *Server) readLoop() {
152160
}
153161
// Otherwise it's a notification — ignore for now
154162
}
155-
// Scanner done — close all pending channels
163+
// Scanner done — log the cause if it was an error (e.g., oversized
164+
// response exceeding the 1MB buffer), then close all pending channels.
165+
if err := s.reader.Err(); err != nil {
166+
slog.Warn("mcp: stdout reader stopped", "server", s.Name, "error", err)
167+
}
156168
s.pendMu.Lock()
157169
for id, ch := range s.pending {
158170
close(ch)
159171
delete(s.pending, id)
172+
// Clean up pendErrors only for requests that will never be
173+
// answered. Entries for already-signaled requests (no longer in
174+
// s.pending) are left for the caller to reap.
175+
delete(s.pendErrors, id)
160176
}
161177
s.pendMu.Unlock()
162178
}
@@ -309,6 +325,7 @@ func (s *Server) callWithTimeout(ctx context.Context, method string, params inte
309325
if err != nil {
310326
s.pendMu.Lock()
311327
delete(s.pending, id)
328+
delete(s.pendErrors, id)
312329
s.pendMu.Unlock()
313330
return nil, fmt.Errorf("write: %w", err)
314331
}
@@ -319,23 +336,40 @@ func (s *Server) callWithTimeout(ctx context.Context, method string, params inte
319336
timeout = time.Until(deadline)
320337
}
321338

339+
// Use time.NewTimer + Stop instead of time.After to avoid leaking
340+
// the timer in the runtime when the response arrives or ctx is
341+
// cancelled before the timeout fires.
342+
timer := time.NewTimer(timeout)
322343
select {
323344
case result, ok := <-ch:
345+
timer.Stop()
324346
if !ok {
325347
return nil, fmt.Errorf("mcp: connection closed")
326348
}
327349
if result == nil {
350+
// Include the server's error code and message if available,
351+
// instead of a generic "server returned error" with no detail.
352+
s.pendMu.Lock()
353+
errMsg := s.pendErrors[id]
354+
delete(s.pendErrors, id)
355+
s.pendMu.Unlock()
356+
if errMsg != "" {
357+
return nil, fmt.Errorf("mcp: server error: %s", errMsg)
358+
}
328359
return nil, fmt.Errorf("mcp: server returned error")
329360
}
330361
return result, nil
331-
case <-time.After(timeout):
362+
case <-timer.C:
332363
s.pendMu.Lock()
333364
delete(s.pending, id)
365+
delete(s.pendErrors, id)
334366
s.pendMu.Unlock()
335367
return nil, fmt.Errorf("mcp: call %s timed out after %s", method, timeout)
336368
case <-ctx.Done():
369+
timer.Stop()
337370
s.pendMu.Lock()
338371
delete(s.pending, id)
372+
delete(s.pendErrors, id)
339373
s.pendMu.Unlock()
340374
return nil, ctx.Err()
341375
}

internal/permissions/guardian.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,9 +256,12 @@ func isBase64Injection(s string) bool {
256256
if len(s) < minBase64Len {
257257
return false
258258
}
259-
// Check if the line is mostly base64 characters (letters, digits, +, /, =)
259+
// Count base64-legal bytes (all ASCII). Using byte iteration instead
260+
// of rune iteration keeps the count consistent with len(s) (which is
261+
// a byte count), so the ratio is correct for multi-byte UTF-8 input.
260262
b64Chars := 0
261-
for _, c := range s {
263+
for i := 0; i < len(s); i++ {
264+
c := s[i]
262265
if (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '+' || c == '/' || c == '=' {
263266
b64Chars++
264267
}

internal/resilience/ratelimit/ratelimit.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,12 @@ func (l *Limiter) Wait(ctx context.Context) error {
7070
}
7171
l.mu.Unlock()
7272

73+
timer := time.NewTimer(waitTime)
7374
select {
7475
case <-ctx.Done():
76+
timer.Stop()
7577
return ctx.Err()
76-
case <-time.After(waitTime):
78+
case <-timer.C:
7779
// Try again
7880
}
7981
}

0 commit comments

Comments
 (0)