diff --git a/client/adaptive_ratelimit.go b/client/adaptive_ratelimit.go index 75099ab..ca76188 100644 --- a/client/adaptive_ratelimit.go +++ b/client/adaptive_ratelimit.go @@ -387,10 +387,12 @@ func (a *AdaptiveRateLimitProvider) checkAndWait(ctx context.Context) error { if delay > 0 && delay <= a.config.MaxDelay { a.throttleCount++ a.mu.Unlock() + timer := time.NewTimer(delay) select { case <-ctx.Done(): + timer.Stop() return ctx.Err() - case <-time.After(delay): + case <-timer.C: } a.mu.Lock() now = time.Now() @@ -421,10 +423,12 @@ func (a *AdaptiveRateLimitProvider) checkAndWait(ctx context.Context) error { if delay > 0 && delay <= a.config.MaxDelay { a.throttleCount++ a.mu.Unlock() + timer := time.NewTimer(delay) select { case <-ctx.Done(): + timer.Stop() return ctx.Err() - case <-time.After(delay): + case <-timer.C: } a.mu.Lock() now = time.Now() diff --git a/client/guardrails.go b/client/guardrails.go index 51fb98e..dbae67f 100644 --- a/client/guardrails.go +++ b/client/guardrails.go @@ -63,6 +63,8 @@ type GuardrailViolation struct { Rule GuardrailRule `json:"rule"` MatchedText string `json:"matched_text"` RedactedResult string `json:"redacted_result,omitempty"` + matchStart int // byte offset of the match in the original response + matchEnd int // byte offset one past the last matched byte } // GuardrailError is returned when a guardrail blocks a response. @@ -91,6 +93,10 @@ func NewGuardrails(rules ...GuardrailRule) *Guardrails { } // AddRule registers a guardrail rule. It panics if the pattern is invalid. +// This follows the regexp.MustCompile convention for programmatic rules +// where an invalid pattern indicates a programmer error. For rules that +// may originate from untrusted sources (config files, user input), use +// AddRuleSafe instead. func (g *Guardrails) AddRule(r GuardrailRule) { if r.Pattern != "" { compiled, err := regexp.Compile(r.Pattern) @@ -104,6 +110,37 @@ func (g *Guardrails) AddRule(r GuardrailRule) { g.rules = append(g.rules, r) } +// AddRuleSafe registers a guardrail rule and returns an error if the pattern +// is invalid, instead of panicking. Use this when rules may come from +// untrusted sources (config files, user input). +func (g *Guardrails) AddRuleSafe(r GuardrailRule) error { + if r.Pattern != "" { + compiled, err := regexp.Compile(r.Pattern) + if err != nil { + return fmt.Errorf("eyrie: guardrails: invalid regex %q in rule %q: %w", r.Pattern, r.Name, err) + } + r.compiled = compiled + } + g.mu.Lock() + defer g.mu.Unlock() + g.rules = append(g.rules, r) + return nil +} + +// NewGuardrailsSafe creates a Guardrails instance and returns an error if any +// rule has an invalid pattern. Use this when rules may come from untrusted +// sources; use NewGuardrails for programmatic rules where invalid patterns +// indicate a programmer error (matching regexp.MustCompile convention). +func NewGuardrailsSafe(rules ...GuardrailRule) (*Guardrails, error) { + g := &Guardrails{} + for _, r := range rules { + if err := g.AddRuleSafe(r); err != nil { + return nil, err + } + } + return g, nil +} + // Rules returns a snapshot of the currently registered rules. func (g *Guardrails) Rules() []GuardrailRule { g.mu.RLock() @@ -134,17 +171,19 @@ func (g *Guardrails) Check(ctx context.Context, response string) ([]GuardrailVio if rule.compiled == nil { continue } - matches := rule.compiled.FindAllString(response, -1) + matches := rule.compiled.FindAllStringIndex(response, -1) if len(matches) == 0 { continue } for _, match := range matches { v := GuardrailViolation{ Rule: rule, - MatchedText: match, + MatchedText: response[match[0]:match[1]], + matchStart: match[0], + matchEnd: match[1], } if rule.Action == GuardrailRedact { - v.RedactedResult = strings.Repeat("*", len(match)) + v.RedactedResult = strings.Repeat("*", len(v.MatchedText)) } violations = append(violations, v) if rule.Action == GuardrailBlock { @@ -165,7 +204,9 @@ func (g *Guardrails) Check(ctx context.Context, response string) ([]GuardrailVio // ApplyRedactions takes the response text and violations, replacing redacted // matches with their redaction markers. Non-redact violations are left intact. -// Matches are applied positionally to handle overlapping patterns correctly. +// Match positions are used directly from the violations (captured during Check) +// so the correct instance of each match is redacted even when the matched text +// appears multiple times in the response. func ApplyRedactions(response string, violations []GuardrailViolation) string { type replacement struct { start int @@ -177,11 +218,17 @@ func ApplyRedactions(response string, violations []GuardrailViolation) string { if v.Rule.Action != GuardrailRedact { continue } - idx := strings.Index(response, v.MatchedText) - if idx < 0 { + if v.matchEnd == 0 && v.matchStart == 0 && v.MatchedText != "" { + // Fallback: violation came from outside Check (e.g. constructed + // manually). Search for the first occurrence. + idx := strings.Index(response, v.MatchedText) + if idx < 0 { + continue + } + reps = append(reps, replacement{start: idx, end: idx + len(v.MatchedText), text: v.RedactedResult}) continue } - reps = append(reps, replacement{start: idx, end: idx + len(v.MatchedText), text: v.RedactedResult}) + reps = append(reps, replacement{start: v.matchStart, end: v.matchEnd, text: v.RedactedResult}) } if len(reps) == 0 { return response diff --git a/client/guardrails_test.go b/client/guardrails_test.go index 6c549d0..ee4d551 100644 --- a/client/guardrails_test.go +++ b/client/guardrails_test.go @@ -164,6 +164,42 @@ func TestGuardrails_InvalidPatternPanics(t *testing.T) { }) } +func TestGuardrails_InvalidPatternSafeReturnsError(t *testing.T) { + _, err := NewGuardrailsSafe(GuardrailRule{ + Type: GuardrailCustom, + Name: "bad_regex", + Pattern: `[invalid`, + Action: GuardrailBlock, + }) + if err == nil { + t.Fatal("expected error for invalid regex in NewGuardrailsSafe") + } +} + +func TestGuardrails_AddRuleSafe(t *testing.T) { + g := NewGuardrails() + if err := g.AddRuleSafe(GuardrailRule{ + Type: GuardrailCustom, + Name: "dynamic_rule", + Pattern: `dynamic_pattern`, + Action: GuardrailWarn, + }); err != nil { + t.Fatalf("expected no error, got: %v", err) + } + if len(g.Rules()) != 1 { + t.Fatalf("expected 1 rule after AddRuleSafe, got %d", len(g.Rules())) + } + + if err := g.AddRuleSafe(GuardrailRule{ + Type: GuardrailCustom, + Name: "bad_regex", + Pattern: `[invalid`, + Action: GuardrailBlock, + }); err == nil { + t.Fatal("expected error for invalid regex in AddRuleSafe") + } +} + func TestGuardrails_AddRule(t *testing.T) { g := NewGuardrails() g.AddRule(GuardrailRule{ diff --git a/client/provider_registry.go b/client/provider_registry.go index 849b392..0e03ad2 100644 --- a/client/provider_registry.go +++ b/client/provider_registry.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log/slog" + "time" "github.com/GrayCodeAI/eyrie/config" "github.com/GrayCodeAI/eyrie/credentials" @@ -280,5 +281,13 @@ func ResolveProviderModelEnvOverride(provider string) string { } func resolveEnvSecret(envKey string) string { - return credentials.LookupSecret(context.Background(), envKey) + // Bound the lookup to prevent indefinite stalls when the OS keyring + // is unresponsive (e.g., locked keychain on macOS, D-Bus failure on + // Linux). The keyring itself has a 30s timeout, but resolveEnvSecret + // is called multiple times in sequence during provider construction + // (up to 6 calls for AWS Bedrock), so a per-call cap keeps the total + // stall bounded. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + return credentials.LookupSecret(ctx, envKey) } diff --git a/client/ratelimit.go b/client/ratelimit.go index fdbce73..59e7ef3 100644 --- a/client/ratelimit.go +++ b/client/ratelimit.go @@ -82,10 +82,12 @@ func (b *tokenBucket) wait(ctx context.Context) error { if since < b.minInterval { wait := b.minInterval - since b.mu.Unlock() + timer := time.NewTimer(wait) select { case <-ctx.Done(): + timer.Stop() return fmt.Errorf("eyrie: rate limiter: %w", ctx.Err()) - case <-time.After(wait): + case <-timer.C: } b.mu.Lock() } @@ -100,10 +102,12 @@ func (b *tokenBucket) wait(ctx context.Context) error { waitDur := time.Duration(needed / b.refillRate) b.mu.Unlock() + timer := time.NewTimer(waitDur) select { case <-ctx.Done(): + timer.Stop() return fmt.Errorf("eyrie: rate limiter: %w", ctx.Err()) - case <-time.After(waitDur): + case <-timer.C: } } } diff --git a/client/retry.go b/client/retry.go index d6c72c7..ee74750 100644 --- a/client/retry.go +++ b/client/retry.go @@ -119,10 +119,16 @@ func doWithRetry(ctx context.Context, httpClient *http.Client, req *http.Request "attempt", attempt, "max", rc.MaxRetries, "delay", delay, "url", req.URL.String(), ) + // Use time.NewTimer + Stop instead of time.After to avoid leaking + // the timer in the runtime when ctx is cancelled before the delay + // elapses. time.After allocates a timer that lives until it fires, + // even if the caller has already moved on. + timer := time.NewTimer(delay) select { case <-ctx.Done(): + timer.Stop() return nil, ctx.Err() - case <-time.After(delay): + case <-timer.C: } } diff --git a/client/stream.go b/client/stream.go index 3be9c8e..08ac8c4 100644 --- a/client/stream.go +++ b/client/stream.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "log/slog" + "sort" "strings" "time" ) @@ -428,9 +429,21 @@ func processOpenAIStreamWithOpts(ctx context.Context, sseEvents <-chan SSEEvent, return } toolsEmitted = true - for _, t := range tools { - var args map[string]interface{} - _ = json.Unmarshal([]byte(t.argsBuf.String()), &args) + // Sort by index so tool calls are emitted in the order the + // model produced them, not in random map-iteration order. + indices := make([]int, 0, len(tools)) + for idx := range tools { + indices = append(indices, idx) + } + sort.Ints(indices) + for _, idx := range indices { + t := tools[idx] + args := map[string]interface{}{} + if err := json.Unmarshal([]byte(t.argsBuf.String()), &args); err != nil { + // On parse failure, pass the raw string as _raw so + // the caller sees something rather than a nil map. + args = map[string]interface{}{"_raw": t.argsBuf.String()} + } emit(ctx, ch, EyrieStreamEvent{ Type: "tool_call", ToolCall: &ToolCall{ID: t.id, Name: t.name, Arguments: args}, diff --git a/config/provider_env.go b/config/provider_env.go index c630338..afdd6bc 100644 --- a/config/provider_env.go +++ b/config/provider_env.go @@ -3,6 +3,7 @@ package config import ( "encoding/json" "fmt" + "net/url" "os" "path/filepath" "strings" @@ -286,13 +287,21 @@ func ValidateAPIKey(apiKey, providerName string) string { return "" } -// ValidateBaseURL validates a base URL. +// ValidateBaseURL validates a base URL. Returns an error message if the URL +// is syntactically invalid (unparseable or missing a scheme), or empty if valid. func ValidateBaseURL(baseURL string) string { if baseURL == "" { return "" } - if _, err := os.Stat(baseURL); err == nil { - return "Invalid base URL: " + baseURL + u, err := url.Parse(baseURL) + if err != nil { + return "Invalid base URL: " + baseURL + " (" + err.Error() + ")" + } + if u.Scheme != "http" && u.Scheme != "https" { + return "Invalid base URL: " + baseURL + " (must be http or https)" + } + if u.Host == "" { + return "Invalid base URL: " + baseURL + " (missing host)" } return "" } @@ -360,7 +369,7 @@ func SaveProviderConfig(config *ProviderConfig, path string) error { if path == "" { path = GetProviderConfigPath() } - _ = os.MkdirAll(filepath.Dir(path), 0o755) + _ = os.MkdirAll(filepath.Dir(path), 0o700) data, err := json.MarshalIndent(config, "", " ") if err != nil { return err diff --git a/config/runtime.go b/config/runtime.go index 43a6e5a..4d27b04 100644 --- a/config/runtime.go +++ b/config/runtime.go @@ -4,6 +4,7 @@ import ( "context" "os" "strings" + "time" "github.com/GrayCodeAI/eyrie/credentials" ) @@ -43,8 +44,13 @@ func envValue(key string) string { if key == "" { return "" } + // Bound the keyring lookup to prevent indefinite stalls when the OS + // keychain is unresponsive. The keyring itself has a 30s timeout, + // but envValue is called many times during provider construction. + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() // Always check the credential store first for ALL keys. - if v := credentials.LookupSecret(context.Background(), key); v != "" { + if v := credentials.LookupSecret(ctx, key); v != "" { return v } // Fall back to process environment only when the credential store has diff --git a/conversation/engine.go b/conversation/engine.go index 7ded09b..bf836f3 100644 --- a/conversation/engine.go +++ b/conversation/engine.go @@ -202,6 +202,9 @@ func (e *Engine) DeleteNode(ctx context.Context, id string) error { const defaultGroupBudgetMultiplier = 4 func (e *Engine) streamAndSave(ctx context.Context, parentNode *storage.Node, messages []client.EyrieMessage, opts PromptOpts) (<-chan Event, error) { + if e.provider == nil { + return nil, fmt.Errorf("conversation: engine has no provider") + } _, span := tracer.Start( ctx, "conversation.streamAndSave", trace.WithAttributes( @@ -248,6 +251,18 @@ func (e *Engine) streamAndSave(ctx context.Context, parentNode *storage.Node, me currentStream = sr.Events ) + // Ensure the last StreamResult is always closed on exit. The + // closure captures currentSR by reference, so it closes whatever + // stream is active when the goroutine returns — including the + // initial stream on early exits and the final continuation stream + // on normal completion. Previous streams are closed explicitly + // during the continuation loop (currentSR.Close() before reassign). + defer func() { + if currentSR != nil { + currentSR.Close() + } + }() + for { var fullTextBuilder strings.Builder var usage *client.EyrieUsage diff --git a/credentials/keyring_platform.go b/credentials/keyring_platform.go index 4ff255e..6eaf694 100644 --- a/credentials/keyring_platform.go +++ b/credentials/keyring_platform.go @@ -80,7 +80,16 @@ func (k *keyringStore) Get(ctx context.Context, account string) (string, error) return "", err } if r.err != nil { - return "", ErrNotFound + // Only map "not found" errors to ErrNotFound; pass through real + // backend errors (locked keychain, permission denied, D-Bus + // failure, etc.) so callers can distinguish "no secret stored" + // from "lookup failed". Without this, LookupSecret would log a + // real backend failure at Debug level ("no secret stored") + // instead of Warn level ("secret lookup failed"). + if isNotFound(r.err) { + return "", ErrNotFound + } + return "", r.err } if strings.TrimSpace(r.val) == "" { return "", ErrNotFound diff --git a/internal/api/auth_test.go b/internal/api/auth_test.go index ca918be..3b02ccf 100644 --- a/internal/api/auth_test.go +++ b/internal/api/auth_test.go @@ -1,16 +1,20 @@ package api -import "testing" +import ( + "testing" + + "github.com/GrayCodeAI/eyrie/internal/httputil" +) func TestConstantTimeEqual(t *testing.T) { key := "super-secret-api-key" - if !constantTimeEqual(key, key) { + if !httputil.ConstantTimeEqual(key, key) { t.Fatal("expected equal tokens to match") } - if constantTimeEqual(key, key+"x") { + if httputil.ConstantTimeEqual(key, key+"x") { t.Fatal("expected different-length tokens to not match") } - if constantTimeEqual("short", "much-longer-token") { + if httputil.ConstantTimeEqual("short", "much-longer-token") { t.Fatal("expected mismatched tokens to not match") } } diff --git a/internal/api/openai_proxy.go b/internal/api/openai_proxy.go index a1c5b4e..64048af 100644 --- a/internal/api/openai_proxy.go +++ b/internal/api/openai_proxy.go @@ -101,12 +101,20 @@ type openAIChatChunk struct { Choices []openAIChunkChoice `json:"choices"` } +// maxOpenAIBodyBytes is the body limit for the OpenAI-compatible proxy +// endpoint. Larger than the native 1 MiB limit because OpenAI chat +// completions carry full multi-turn conversations with tool definitions. +const maxOpenAIBodyBytes = 10 << 20 // 10 MiB + // handleOpenAIChatCompletions implements POST /v1/chat/completions. func (s *Server) handleOpenAIChatCompletions(w http.ResponseWriter, r *http.Request) { // OpenAI clients send many fields eyrie does not consume (seed, logprobs, // stream_options, ...). Decode leniently rather than with the strict // unknown-field rejection used by decodeJSONBody. - r.Body = http.MaxBytesReader(w, r.Body, maxRequestBodyBytes) + // Use a larger body limit than the native /prompt endpoint (1 MiB) because + // OpenAI chat completions carry full multi-turn conversations with large + // system prompts and tool definitions. + r.Body = http.MaxBytesReader(w, r.Body, maxOpenAIBodyBytes) var req openAIChatRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request body"}) @@ -229,6 +237,12 @@ func (s *Server) streamOpenAIResponse(w http.ResponseWriter, ctx context.Context _, finish := s.openAIUsageForNode(ctx, nodeID) if errMsg != "" { + // Surface the error as an SSE data event so the client knows the + // generation failed, rather than silently replacing the finish + // reason with "stop". OpenAI clients expect a JSON error object. + errPayload, _ := json.Marshal(map[string]string{"error": errMsg}) + _, _ = fmt.Fprintf(w, "data: %s\n\n", errPayload) + flusher.Flush() finish = "stop" } finishReason := finish diff --git a/internal/api/server.go b/internal/api/server.go index 309ce7c..95f2297 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -2,17 +2,15 @@ package api import ( "context" - "crypto/subtle" "encoding/json" "fmt" - "io" "net/http" - "strings" "time" "github.com/GrayCodeAI/eyrie/client" "github.com/GrayCodeAI/eyrie/conversation" eyrie "github.com/GrayCodeAI/eyrie/internal/health" + "github.com/GrayCodeAI/eyrie/internal/httputil" "github.com/GrayCodeAI/eyrie/storage" ) @@ -29,6 +27,7 @@ type Server struct { mux *http.ServeMux handler http.Handler // traced handler wrapping mux bgCtx context.Context + bgCancel context.CancelFunc // cancelled on Shutdown to release bgCtx httpSrv *http.Server } @@ -48,7 +47,6 @@ type Config struct { func NewServer(cfg Config) *Server { ctx, cancel := context.WithCancel(context.Background()) - _ = cancel // cancelled on server shutdown if needed s := &Server{ engine: conversation.New(cfg.Store, cfg.Provider), store: cfg.Store, @@ -58,9 +56,10 @@ func NewServer(cfg Config) *Server { virtualKeyFor: cfg.VirtualKeyResolver, mux: http.NewServeMux(), bgCtx: ctx, + bgCancel: cancel, } s.routes() - s.handler = TracingMiddleware(s.mux) + s.handler = httputil.SecurityHeaders(TracingMiddleware(s.mux)) return s } @@ -69,6 +68,9 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func (s *Server) ListenAndServe(addr string) error { + if err := httputil.ValidateAuthConfig(addr, s.apiKey); err != nil { + return err + } s.httpSrv = &http.Server{ Addr: addr, Handler: s.handler, @@ -82,6 +84,9 @@ func (s *Server) ListenAndServe(addr string) error { // Shutdown gracefully shuts down the HTTP server without interrupting active connections. func (s *Server) Shutdown() error { + if s.bgCancel != nil { + s.bgCancel() + } if s.httpSrv == nil { return nil } @@ -114,15 +119,11 @@ func (s *Server) routes() { func (s *Server) auth(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - token := r.Header.Get("Authorization") - token = strings.TrimPrefix(token, "Bearer ") - if token == "" { - token = r.Header.Get("X-API-Key") - } + token := httputil.ExtractBearerToken(r) if s.apiKey != "" { - if !constantTimeEqual(token, s.apiKey) { - writeJSON(w, http.StatusUnauthorized, map[string]string{"error": "unauthorized"}) + if !httputil.ConstantTimeEqual(token, s.apiKey) { + httputil.WriteJSON(w, http.StatusUnauthorized, map[string]string{"error": "unauthorized"}) return } } @@ -139,31 +140,6 @@ func (s *Server) auth(next http.HandlerFunc) http.HandlerFunc { } } -func constantTimeEqual(a, b string) bool { - // Pad the shorter value so comparison time does not leak token length. - if len(a) < len(b) { - a += strings.Repeat("\x00", len(b)-len(a)) - } else if len(b) < len(a) { - b += strings.Repeat("\x00", len(a)-len(b)) - } - return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1 -} - -func decodeJSONBody(w http.ResponseWriter, r *http.Request, dst any) bool { - r.Body = http.MaxBytesReader(w, r.Body, maxRequestBodyBytes) - dec := json.NewDecoder(r.Body) - dec.DisallowUnknownFields() - if err := dec.Decode(dst); err != nil { - writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request body"}) - return false - } - if err := dec.Decode(&struct{}{}); err != io.EOF { - writeJSON(w, http.StatusBadRequest, map[string]string{"error": "request body must contain a single JSON object"}) - return false - } - return true -} - func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request) { writeJSON(w, http.StatusOK, map[string]string{"status": "ok"}) } @@ -362,8 +338,14 @@ func (s *Server) collectAndRespond(w http.ResponseWriter, events <-chan conversa }) } +// writeJSON and decodeJSONBody delegate to internal/httputil for the +// canonical implementation. They remain as package-local wrappers so +// existing call sites in rerank.go, analytics.go, and openai_proxy.go +// don't need to be updated individually. func writeJSON(w http.ResponseWriter, status int, v interface{}) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - _ = json.NewEncoder(w).Encode(v) + httputil.WriteJSON(w, status, v) +} + +func decodeJSONBody(w http.ResponseWriter, r *http.Request, dst any) bool { + return httputil.DecodeJSONBody(w, r, dst) } diff --git a/internal/cache/backend.go b/internal/cache/backend.go index c120a8e..15040df 100644 --- a/internal/cache/backend.go +++ b/internal/cache/backend.go @@ -100,6 +100,11 @@ func (m *MemoryBackend) Delete(_ context.Context, key string) error { // (i.e. a missing key). It is handled internally and never surfaced from Get. var ErrRedisNil = errors.New("eyrie: redis nil reply") +// maxRespBulkLen caps the size of a single RESP bulk-string reply to prevent +// memory-exhaustion from a malicious or buggy Redis server. 64 MB is well +// beyond any legitimate cache value. +const maxRespBulkLen = 64 * 1024 * 1024 + // RedisBackend is a minimal, dependency-free CacheBackend backed by Redis. // // It implements only the subset of commands the cache needs: @@ -284,6 +289,9 @@ func readReply(r *bufio.Reader) ([]byte, error) { if n < 0 { return nil, ErrRedisNil } + if n > maxRespBulkLen { + return nil, fmt.Errorf("eyrie: redis bulk string length %d exceeds max %d", n, maxRespBulkLen) + } buf := make([]byte, n+2) // include trailing CRLF if _, err := readFull(r, buf); err != nil { return nil, err diff --git a/internal/httputil/httputil.go b/internal/httputil/httputil.go new file mode 100644 index 0000000..23dffec --- /dev/null +++ b/internal/httputil/httputil.go @@ -0,0 +1,117 @@ +// Package httputil provides shared HTTP server primitives used across +// eyrie's API surfaces. Centralizing these eliminates drift in auth +// comparison, body decoding, JSON responses, security headers, and +// loopback validation that previously existed as duplicated copies +// in internal/api/server.go and other handlers. +package httputil + +import ( + "crypto/subtle" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "strings" +) + +// MaxRequestBodyBytes is the default maximum request body size (1 MiB). +const MaxRequestBodyBytes = 1 << 20 + +// ConstantTimeEqual compares two strings in constant time. The shorter +// value is padded with NUL bytes so comparison time does not leak token +// length. This is the preferred bearer-token comparison for API servers. +func ConstantTimeEqual(a, b string) bool { + if len(a) < len(b) { + a += strings.Repeat("\x00", len(b)-len(a)) + } else if len(b) < len(a) { + b += strings.Repeat("\x00", len(a)-len(b)) + } + return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1 +} + +// DecodeJSONBody decodes a JSON request body into dst with a size limit +// and strict unknown-field rejection. Returns true on success. On failure +// it writes a 400 JSON error response and returns false. +func DecodeJSONBody(w http.ResponseWriter, r *http.Request, dst any) bool { + return DecodeJSONBodyWithLimit(w, r, dst, MaxRequestBodyBytes) +} + +// DecodeJSONBodyWithLimit is like DecodeJSONBody but with a custom body +// size limit. +func DecodeJSONBodyWithLimit(w http.ResponseWriter, r *http.Request, dst any, maxBytes int64) bool { + r.Body = http.MaxBytesReader(w, r.Body, maxBytes) + dec := json.NewDecoder(r.Body) + dec.DisallowUnknownFields() + if err := dec.Decode(dst); err != nil { + WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid request body"}) + return false + } + if err := dec.Decode(&struct{}{}); err != io.EOF { + WriteJSON(w, http.StatusBadRequest, map[string]string{"error": "request body must contain a single JSON object"}) + return false + } + return true +} + +// WriteJSON writes a JSON response with the given status code. +func WriteJSON(w http.ResponseWriter, status int, v interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(v) +} + +// SecurityHeaders wraps an http.Handler with standard security headers +// (X-Content-Type-Options, X-Frame-Options, Cache-Control). +func SecurityHeaders(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("X-Frame-Options", "DENY") + w.Header().Set("Cache-Control", "no-store") + next.ServeHTTP(w, r) + }) +} + +// IsLoopbackHost reports whether host is a loopback address: +// 127.0.0.0/8, ::1, or "localhost". An empty string is treated as +// non-loopback (fail-safe). +func IsLoopbackHost(host string) bool { + if host == "" || host == "localhost" { + return host == "localhost" + } + if ip := net.ParseIP(host); ip != nil { + return ip.IsLoopback() + } + return false +} + +// ValidateAuthConfig refuses to start a server with no API key on a +// non-loopback bind. Returns nil if the API key is set or the bind +// address is loopback. +func ValidateAuthConfig(addr, apiKey string) error { + if apiKey != "" { + return nil + } + host, _, err := net.SplitHostPort(addr) + if err != nil { + return fmt.Errorf("invalid bind address %q: %w", addr, err) + } + if !IsLoopbackHost(host) { + return fmt.Errorf("API key is empty and bind address %q is not loopback; refusing to start. Set an API key or bind to 127.0.0.1", addr) + } + return nil +} + +// ExtractBearerToken extracts a bearer/API-key token from request headers. +// It checks "Authorization: Bearer ..." first, then "X-API-Key". +// The Bearer scheme is matched case-insensitively per RFC 7235. +func ExtractBearerToken(r *http.Request) string { + token := r.Header.Get("Authorization") + if len(token) > 7 && strings.EqualFold(token[:7], "Bearer ") { + return token[7:] + } + if token == "" { + token = r.Header.Get("X-API-Key") + } + return token +} diff --git a/internal/httputil/httputil_test.go b/internal/httputil/httputil_test.go new file mode 100644 index 0000000..3a73e62 --- /dev/null +++ b/internal/httputil/httputil_test.go @@ -0,0 +1,108 @@ +package httputil + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestConstantTimeEqual(t *testing.T) { + key := "super-secret-api-key" + if !ConstantTimeEqual(key, key) { + t.Fatal("expected equal tokens to match") + } + if ConstantTimeEqual(key, key+"x") { + t.Fatal("expected different-length tokens to not match") + } + if ConstantTimeEqual("short", "much-longer-token") { + t.Fatal("expected mismatched tokens to not match") + } + if !ConstantTimeEqual("", "") { + t.Fatal("expected empty strings to match") + } +} + +func TestIsLoopbackHost(t *testing.T) { + cases := []struct { + host string + want bool + }{ + {"127.0.0.1", true}, + {"127.1.2.3", true}, + {"::1", true}, + {"localhost", true}, + {"", false}, + {"0.0.0.0", false}, + {"10.0.0.1", false}, + {"example.com", false}, + } + for _, tc := range cases { + if got := IsLoopbackHost(tc.host); got != tc.want { + t.Errorf("IsLoopbackHost(%q) = %v, want %v", tc.host, got, tc.want) + } + } +} + +func TestValidateAuthConfig(t *testing.T) { + if err := ValidateAuthConfig("127.0.0.1:8080", ""); err != nil { + t.Errorf("loopback with no key should be allowed: %v", err) + } + if err := ValidateAuthConfig("0.0.0.0:8080", "secret"); err != nil { + t.Errorf("non-loopback with key should be allowed: %v", err) + } + if err := ValidateAuthConfig("0.0.0.0:8080", ""); err == nil { + t.Error("non-loopback with no key should be rejected") + } +} + +func TestExtractBearerToken(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("Authorization", "Bearer my-token") + if got := ExtractBearerToken(req); got != "my-token" { + t.Errorf("got %q, want %q", got, "my-token") + } + + req2 := httptest.NewRequest("GET", "/", nil) + req2.Header.Set("X-API-Key", "api-key-token") + if got := ExtractBearerToken(req2); got != "api-key-token" { + t.Errorf("got %q, want %q", got, "api-key-token") + } + + req3 := httptest.NewRequest("GET", "/", nil) + if got := ExtractBearerToken(req3); got != "" { + t.Errorf("got %q, want empty", got) + } +} + +func TestWriteJSON(t *testing.T) { + w := httptest.NewRecorder() + WriteJSON(w, http.StatusOK, map[string]string{"status": "ok"}) + if w.Code != http.StatusOK { + t.Errorf("status = %d, want %d", w.Code, http.StatusOK) + } + if ct := w.Header().Get("Content-Type"); ct != "application/json" { + t.Errorf("content-type = %q, want %q", ct, "application/json") + } + body := w.Body.String() + if !strings.Contains(body, `"status":"ok"`) { + t.Errorf("body = %q, want to contain status:ok", body) + } +} + +func TestSecurityHeaders(t *testing.T) { + handler := SecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + w := httptest.NewRecorder() + handler.ServeHTTP(w, httptest.NewRequest("GET", "/", nil)) + if w.Header().Get("X-Content-Type-Options") != "nosniff" { + t.Error("missing X-Content-Type-Options") + } + if w.Header().Get("X-Frame-Options") != "DENY" { + t.Error("missing X-Frame-Options") + } + if w.Header().Get("Cache-Control") != "no-store" { + t.Error("missing Cache-Control") + } +} diff --git a/internal/observability/audit.go b/internal/observability/audit.go index d90c45d..99c3103 100644 --- a/internal/observability/audit.go +++ b/internal/observability/audit.go @@ -62,7 +62,7 @@ type JSONLFileSink struct { // writes and returns a sink that writes to it. The caller owns the sink and // should call Close when done. func NewJSONLFileSink(path string) (*JSONLFileSink, error) { - f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) + f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o600) if err != nil { return nil, err } diff --git a/router/retry.go b/router/retry.go index 4e0e16a..7f5e769 100644 --- a/router/retry.go +++ b/router/retry.go @@ -59,4 +59,7 @@ func BackoffDelay(attempt int, cfg RetryConfig) time.Duration { return types.BackoffDelay(attempt, cfg.BaseDelay, cfg.MaxDelay) } -var afterFunc = time.After +// newTimer is a variable so tests can inject a fake timer. Uses +// time.NewTimer (not time.After) to avoid leaking the timer in the +// runtime when ctx is cancelled before the delay elapses. +var newTimer = time.NewTimer diff --git a/router/router.go b/router/router.go index a5d10e3..47c9631 100644 --- a/router/router.go +++ b/router/router.go @@ -211,10 +211,12 @@ func (r *Router) chatWithRetry(ctx context.Context, p client.Provider, messages if cfg.OnRetry != nil { cfg.OnRetry(RetryEvent{Err: err, Attempt: attempt + 1, MaxRetries: cfg.MaxRetries, Delay: delay}) } + timer := newTimer(delay) select { case <-ctx.Done(): + timer.Stop() return nil, ctx.Err() - case <-afterFunc(delay): + case <-timer.C: } } } diff --git a/runtime/runtime.go b/runtime/runtime.go index 1557d16..90d82b4 100644 --- a/runtime/runtime.go +++ b/runtime/runtime.go @@ -29,6 +29,7 @@ import ( "os" "path/filepath" "strings" + "time" "github.com/GrayCodeAI/eyrie/catalog" "github.com/GrayCodeAI/eyrie/client" @@ -213,12 +214,16 @@ func (r *Runtime) CredentialTargets() []CredentialTarget { continue } seen[env] = true + // Bound the keyring lookup so a stuck keychain doesn't + // block the entire CredentialTargets enumeration. + probeCtx, probeCancel := context.WithTimeout(context.Background(), 5*time.Second) out = append(out, CredentialTarget{ ProviderID: depID, DeploymentID: id, EnvVar: env, - Set: credentials.HasSecret(context.Background(), env), + Set: credentials.HasSecret(probeCtx, env), }) + probeCancel() } } return out diff --git a/storage/budgets.go b/storage/budgets.go index f634fd5..6c9f30b 100644 --- a/storage/budgets.go +++ b/storage/budgets.go @@ -5,6 +5,7 @@ import ( "database/sql" "errors" "fmt" + "os" "time" _ "modernc.org/sqlite" @@ -41,7 +42,8 @@ type BudgetStore struct { } // OpenBudgetStore opens (or creates) a SQLite database at path and ensures the -// budget schema exists. +// budget schema exists. The database file is set to 0o600 because the +// virtual_key_secrets table stores provider API keys. func OpenBudgetStore(path string) (*BudgetStore, error) { db, err := sql.Open("sqlite", path+"?_pragma=journal_mode(wal)&_pragma=busy_timeout(5000)&_pragma=foreign_keys(on)") if err != nil { @@ -53,6 +55,16 @@ func OpenBudgetStore(path string) (*BudgetStore, error) { _ = db.Close() return nil, err } + // Restrict file permissions: the database stores plaintext provider API + // keys in virtual_key_secrets. The file is created by the SQLite driver + // with the process umask, which may be 0o644 on some systems. + // In WAL mode SQLite also creates -wal and -shm sidecar + // files; the WAL holds uncheckpointed pages (including plaintext keys) + // so it must be tightened too. Errors are ignored for sidecars that + // don't exist yet. + _ = os.Chmod(path, 0o600) + _ = os.Chmod(path+"-wal", 0o600) + _ = os.Chmod(path+"-shm", 0o600) return s, nil } diff --git a/storage/sqlite.go b/storage/sqlite.go index 218b5f9..fe2faaa 100644 --- a/storage/sqlite.go +++ b/storage/sqlite.go @@ -99,7 +99,9 @@ func (s *SQLiteStore) GetNode(ctx context.Context, id string) (*Node, error) { func (s *SQLiteStore) GetNodeByPrefix(ctx context.Context, prefix string) (*Node, error) { // Escape SQL LIKE wildcards in the prefix to prevent false matches. - escaped := strings.NewReplacer("%", "\\%", "_", "\\_").Replace(prefix) + // Backslash must be escaped first (before % and _ are escaped to \% + // and \_) so the escape character itself is not misinterpreted. + escaped := strings.NewReplacer("\\", "\\\\", "%", "\\%", "_", "\\_").Replace(prefix) return s.scanNode(s.db.QueryRowContext(ctx, `SELECT id, parent_id, root_id, sequence, node_type, content, provider, model, tokens_in, tokens_out, tokens_cache_read, tokens_cache_creation, tokens_reasoning, latency_ms, stop_reason, output_group_id, status, title, system_prompt, metadata, created_at FROM nodes WHERE id LIKE ? ESCAPE '\' LIMIT 1`, escaped+"%")) }