Skip to content

Commit 5e7f7fc

Browse files
committed
fix: add backpressure logging, query timeouts, configurable privacy, and response limits
- Add rate-limited warning logging when slow-path queue drops jobs - Add path canonicalization and directory validation in git watcher - Return errors from critical hook operations instead of silent logging - Add configurable FilterLevel (Strict/Moderate/Minimal) for privacy - Trigger async flush when access tracker buffer full (instead of dropping) - Add query timeout wrapper (30s default) in SQLite storage - Add response size limit (5MB) with 413 on overflow in REST API
1 parent 9aaf4ba commit 5e7f7fc

9 files changed

Lines changed: 200 additions & 24 deletions

File tree

engine/access_tracker.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,19 @@ func (at *AccessTracker) Log(ctx context.Context, nodeID string) {
4646
at.mu.Lock()
4747
if len(at.buffer) < maxAccessBuffer {
4848
at.buffer = append(at.buffer, nodeID)
49-
} else {
50-
slog.Warn("access_tracker: buffer full, dropping access log", "node_id", nodeID)
49+
at.mu.Unlock()
50+
return
5151
}
5252
at.mu.Unlock()
53+
54+
// Buffer full — trigger an async flush instead of dropping
55+
slog.Warn("access_tracker: buffer full, triggering forced flush", "node_id", nodeID)
56+
go at.Flush(context.Background())
57+
58+
// After triggering flush, still try to buffer this item
59+
at.mu.Lock()
60+
at.buffer = append(at.buffer, nodeID)
61+
at.mu.Unlock()
5362
}
5463

5564
// Flush immediately applies all pending access counts to nodes.

git/watcher.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ package git
33
import (
44
"context"
55
"fmt"
6+
"os"
67
"os/exec"
8+
"path/filepath"
79
"strings"
810
"time"
911

@@ -26,8 +28,22 @@ type Watcher struct {
2628
}
2729

2830
// New creates a git Watcher for the given project directory.
29-
func New(store storage.Storage, g graph.Graph, dir string) *Watcher {
30-
return &Watcher{store: store, graph: g, dir: dir}
31+
// Returns an error if the directory path is invalid or does not exist.
32+
func New(store storage.Storage, g graph.Graph, dir string) (*Watcher, error) {
33+
absDir, err := filepath.Abs(dir)
34+
if err != nil {
35+
return nil, fmt.Errorf("git watcher: invalid path %q: %w", dir, err)
36+
}
37+
absDir = filepath.Clean(absDir)
38+
39+
info, err := os.Stat(absDir)
40+
if err != nil {
41+
return nil, fmt.Errorf("git watcher: path %q does not exist: %w", absDir, err)
42+
}
43+
if !info.IsDir() {
44+
return nil, fmt.Errorf("git watcher: path %q is not a directory", absDir)
45+
}
46+
return &Watcher{store: store, graph: g, dir: absDir}, nil
3147
}
3248

3349
// StalesSince returns stale reports for files changed since the given time.

hooks/runner.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,14 @@ func (r *Runner) SessionStart(ctx context.Context, in *HookInput) error {
7474
}
7575

7676
// Write session ID to a temp file for other hooks to pick up
77-
if err := os.WriteFile(sessionFile(r.project), []byte(sessionID), 0600); err != nil {
78-
// Best-effort: log but don't fail the hook
77+
sf := sessionFile(r.project)
78+
if err := os.MkdirAll(filepath.Dir(sf), 0755); err != nil {
79+
fmt.Fprintf(os.Stderr, "yaad: warning: could not create session dir: %v\n", err)
80+
return fmt.Errorf("create session dir: %w", err)
81+
}
82+
if err := os.WriteFile(sf, []byte(sessionID), 0600); err != nil {
7983
fmt.Fprintf(os.Stderr, "yaad: warning: could not write session file: %v\n", err)
84+
return fmt.Errorf("write session file: %w", err)
8085
}
8186

8287
// Auto-decay: keep graph lean without manual intervention
@@ -143,6 +148,7 @@ func (r *Runner) SessionEnd(ctx context.Context, in *HookInput) error {
143148
Agent: in.Agent,
144149
}); err != nil {
145150
fmt.Fprintf(os.Stderr, "yaad: warning: could not store session summary: %v\n", err)
151+
return fmt.Errorf("store session summary: %w", err)
146152
}
147153
}
148154

@@ -152,6 +158,9 @@ func (r *Runner) SessionEnd(ctx context.Context, in *HookInput) error {
152158
// Clean up session file (best-effort)
153159
if rmErr := os.Remove(sessionFile(r.project)); rmErr != nil {
154160
fmt.Fprintf(os.Stderr, "yaad: warning: could not remove session file: %v\n", rmErr)
161+
if err == nil {
162+
return fmt.Errorf("remove session file: %w", rmErr)
163+
}
155164
}
156165
return err
157166
}

ingest/dual_stream.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"context"
1313
"log/slog"
1414
"sync"
15+
"sync/atomic"
1516
"time"
1617

1718
"github.com/google/uuid"
@@ -39,6 +40,9 @@ type DualStream struct {
3940
once sync.Once
4041
lastNode map[string]string // project → last node ID (temporal backbone)
4142
mu sync.Mutex
43+
44+
dropped atomic.Int64
45+
lastDropLog atomic.Int64 // unix timestamp of last drop warning
4246
}
4347

4448
// New creates a DualStream ingestion manager.
@@ -90,6 +94,15 @@ func (ds *DualStream) Remember(ctx context.Context, in engine.RememberInput) (*s
9094
}:
9195
default:
9296
// Queue full — skip slow path for this node (graceful degradation)
97+
total := ds.dropped.Add(1)
98+
now := time.Now().Unix()
99+
last := ds.lastDropLog.Load()
100+
if now-last >= 10 {
101+
if ds.lastDropLog.CompareAndSwap(last, now) {
102+
slog.Warn("slow-path queue full, dropping jobs",
103+
"dropped_total", total, "node_id", node.ID)
104+
}
105+
}
93106
}
94107

95108
return node, nil

internal/server/mcp.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,10 @@ func (s *MCPServer) handleStale(ctx context.Context, req mcp.CallToolRequest) (*
364364
if project == "" {
365365
project, _ = os.Getwd()
366366
}
367-
watcher := gitwatch.New(s.eng.Store(), s.eng.Graph(), project)
367+
watcher, err := gitwatch.New(s.eng.Store(), s.eng.Graph(), project)
368+
if err != nil {
369+
return nil, err
370+
}
368371
since := time.Now().Add(-7 * 24 * time.Hour)
369372
reports, err := watcher.StalesSince(ctx, since)
370373
if err != nil {

internal/server/rest.go

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ import (
2525
)
2626

2727
const (
28-
maxRequestBodySize = 1 << 20 // 1 MB
28+
maxRequestBodySize = 1 << 20 // 1 MB
29+
maxResponseSize = 5 * (1 << 20) // 5 MB
2930
maxGraphDepth = 5
3031
)
3132

@@ -201,7 +202,7 @@ func (s *RESTServer) handleRecall(w http.ResponseWriter, r *http.Request) {
201202
httpErr(w, err, 500)
202203
return
203204
}
204-
httpJSON(w, result, 200)
205+
httpJSONCapped(w, result, 200)
205206
}
206207

207208
func (s *RESTServer) handleContext(w http.ResponseWriter, r *http.Request) {
@@ -211,7 +212,7 @@ func (s *RESTServer) handleContext(w http.ResponseWriter, r *http.Request) {
211212
httpErr(w, err, 500)
212213
return
213214
}
214-
httpJSON(w, result, 200)
215+
httpJSONCapped(w, result, 200)
215216
}
216217

217218
func (s *RESTServer) handleLink(w http.ResponseWriter, r *http.Request) {
@@ -275,7 +276,7 @@ func (s *RESTServer) handleSubgraph(w http.ResponseWriter, r *http.Request) {
275276
httpErr(w, err, 500)
276277
return
277278
}
278-
httpJSON(w, sg, 200)
279+
httpJSONCapped(w, sg, 200)
279280
}
280281

281282
func (s *RESTServer) handleImpact(w http.ResponseWriter, r *http.Request) {
@@ -296,7 +297,7 @@ func (s *RESTServer) handleImpact(w http.ResponseWriter, r *http.Request) {
296297
nodes = append(nodes, n)
297298
}
298299
}
299-
httpJSON(w, nodes, 200)
300+
httpJSONCapped(w, nodes, 200)
300301
}
301302

302303
func (s *RESTServer) handleForget(w http.ResponseWriter, r *http.Request) {
@@ -475,7 +476,11 @@ func (s *RESTServer) handleStale(w http.ResponseWriter, r *http.Request) {
475476
httpJSON(w, map[string]string{"status": "no project directory configured"}, 200)
476477
return
477478
}
478-
watcher := gitwatch.New(s.eng.Store(), s.eng.Graph(), s.projectDir)
479+
watcher, err := gitwatch.New(s.eng.Store(), s.eng.Graph(), s.projectDir)
480+
if err != nil {
481+
httpErr(w, err, 500)
482+
return
483+
}
479484
since := time.Now().Add(-7 * 24 * time.Hour) // last 7 days
480485
reports, err := watcher.StalesSince(r.Context(), since)
481486
if err != nil {
@@ -537,7 +542,7 @@ func (s *RESTServer) handleHybridRecall(w http.ResponseWriter, r *http.Request)
537542
return
538543
}
539544
reranked := engine.Rerank(r.Context(), scored, s.eng.Store())
540-
httpJSON(w, reranked, 200)
545+
httpJSONCapped(w, reranked, 200)
541546
}
542547

543548
func (s *RESTServer) handleProactive(w http.ResponseWriter, r *http.Request) {
@@ -759,6 +764,23 @@ func httpJSON(w http.ResponseWriter, v any, code int) {
759764
json.NewEncoder(w).Encode(v)
760765
}
761766

767+
// httpJSONCapped encodes the response as JSON but returns a 413 error if the
768+
// serialized payload exceeds maxResponseSize.
769+
func httpJSONCapped(w http.ResponseWriter, v any, code int) {
770+
data, err := json.Marshal(v)
771+
if err != nil {
772+
httpErr(w, fmt.Errorf("marshal response: %w", err), 500)
773+
return
774+
}
775+
if len(data) > maxResponseSize {
776+
httpErr(w, fmt.Errorf("response size %d bytes exceeds limit of %d bytes", len(data), maxResponseSize), 413)
777+
return
778+
}
779+
w.Header().Set("Content-Type", "application/json")
780+
w.WriteHeader(code)
781+
w.Write(data)
782+
}
783+
762784
func httpErr(w http.ResponseWriter, err error, code int) {
763785
httpJSON(w, map[string]string{"error": err.Error()}, code)
764786
}

privacy/filter.go

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,23 @@ import (
66
"strings"
77
)
88

9-
var patterns = []*regexp.Regexp{
9+
// FilterLevel controls how aggressively content is filtered.
10+
type FilterLevel int
11+
12+
const (
13+
// Strict strips everything (emails, IPs, phones, secrets).
14+
Strict FilterLevel = iota
15+
// Moderate keeps infrastructure IPs (10.x, 192.168.x) and work-related emails.
16+
Moderate
17+
// Minimal only strips high-entropy secrets and explicit API keys.
18+
Minimal
19+
)
20+
21+
// DefaultFilterLevel is used by Filter() when no level is specified.
22+
var DefaultFilterLevel = Moderate
23+
24+
// secretPatterns are always stripped (all levels including Minimal).
25+
var secretPatterns = []*regexp.Regexp{
1026
// API keys
1127
regexp.MustCompile(`sk-[a-zA-Z0-9]{20,}`), // OpenAI
1228
regexp.MustCompile(`AKIA[A-Z0-9]{16}`), // AWS Access Key
@@ -26,23 +42,55 @@ var patterns = []*regexp.Regexp{
2642
regexp.MustCompile(`-----BEGIN\s+\w+\s+PRIVATE\s+KEY-----[\s\S]*?-----END\s+\w+\s+PRIVATE\s+KEY-----`),
2743
// Connection strings with passwords
2844
regexp.MustCompile(`(?i)(postgres|mysql|mongodb)://[^\s"]+:[^\s"]+@[^\s"]+`),
29-
// PII: email addresses
30-
regexp.MustCompile(`[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}`),
45+
}
46+
47+
// piiPatterns are stripped by Strict and Moderate (but not Minimal).
48+
var piiPatterns = []*regexp.Regexp{
3149
// PII: US phone numbers
3250
regexp.MustCompile(`\b(?:\+?1[-.\s]?)?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}\b`),
3351
// PII: US Social Security Numbers
3452
regexp.MustCompile(`\b\d{3}-\d{2}-\d{4}\b`),
35-
// PII: IPv4 addresses (non-localhost, non-private documentation ranges)
36-
regexp.MustCompile(`\b(?:(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\b`),
3753
// PII: credit card numbers (basic Luhn-eligible patterns)
3854
regexp.MustCompile(`\b(?:\d[ -]*?){13,19}\b`),
3955
}
4056

41-
// Filter replaces secrets in content with [REDACTED].
57+
// strictOnlyPatterns are only stripped in Strict mode.
58+
var strictOnlyPatterns = []*regexp.Regexp{
59+
// PII: email addresses (Moderate keeps work emails)
60+
regexp.MustCompile(`[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}`),
61+
// PII: IPv4 addresses (Moderate keeps infrastructure IPs like 10.x, 192.168.x)
62+
regexp.MustCompile(`\b(?:(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\b`),
63+
}
64+
65+
// patterns is kept for backward compatibility — contains all patterns (Strict behavior).
66+
var patterns = append(append(append([]*regexp.Regexp{}, secretPatterns...), piiPatterns...), strictOnlyPatterns...)
67+
68+
// Filter replaces secrets in content with [REDACTED] using DefaultFilterLevel.
4269
func Filter(content string) string {
43-
for _, p := range patterns {
70+
return FilterWithLevel(content, DefaultFilterLevel)
71+
}
72+
73+
// FilterWithLevel replaces secrets in content with [REDACTED] at the specified level.
74+
func FilterWithLevel(content string, level FilterLevel) string {
75+
// All levels strip explicit secrets
76+
for _, p := range secretPatterns {
4477
content = p.ReplaceAllString(content, "[REDACTED]")
4578
}
79+
80+
// Moderate and Strict also strip PII patterns (phones, SSNs, credit cards)
81+
if level <= Moderate {
82+
for _, p := range piiPatterns {
83+
content = p.ReplaceAllString(content, "[REDACTED]")
84+
}
85+
}
86+
87+
// Strict strips everything including emails and all IPs
88+
if level == Strict {
89+
for _, p := range strictOnlyPatterns {
90+
content = p.ReplaceAllString(content, "[REDACTED]")
91+
}
92+
}
93+
4694
// Catch high-entropy tokens that regexes might miss.
4795
// Only target tokens that look like standalone secrets (no JSON, no code).
4896
for _, word := range strings.Fields(content) {

privacy/filter_test.go

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ func TestFilter_PII(t *testing.T) {
4949
name string
5050
input string
5151
}{
52-
{"email", "contact user@example.com for help"},
5352
{"SSN", "ssn is 123-45-6789"},
5453
{"phone", "call 555-123-4567"},
5554
}
@@ -63,6 +62,48 @@ func TestFilter_PII(t *testing.T) {
6362
}
6463
}
6564

65+
func TestFilter_Strict(t *testing.T) {
66+
// Strict mode strips emails and all IPs
67+
result := FilterWithLevel("contact user@example.com for help", Strict)
68+
if !strings.Contains(result, "[REDACTED]") {
69+
t.Errorf("expected email redaction in Strict mode, got: %s", result)
70+
}
71+
result = FilterWithLevel("server at 203.0.113.5", Strict)
72+
if !strings.Contains(result, "[REDACTED]") {
73+
t.Errorf("expected IP redaction in Strict mode, got: %s", result)
74+
}
75+
}
76+
77+
func TestFilter_Moderate(t *testing.T) {
78+
// Moderate keeps work-related emails and infrastructure IPs
79+
result := FilterWithLevel("contact user@example.com for help", Moderate)
80+
if strings.Contains(result, "[REDACTED]") {
81+
t.Errorf("Moderate should keep emails, got: %s", result)
82+
}
83+
result = FilterWithLevel("server at 10.0.0.1", Moderate)
84+
if strings.Contains(result, "[REDACTED]") {
85+
t.Errorf("Moderate should keep infrastructure IPs, got: %s", result)
86+
}
87+
// But Moderate still strips SSN and secrets
88+
result = FilterWithLevel("ssn is 123-45-6789", Moderate)
89+
if !strings.Contains(result, "[REDACTED]") {
90+
t.Errorf("Moderate should redact SSN, got: %s", result)
91+
}
92+
}
93+
94+
func TestFilter_Minimal(t *testing.T) {
95+
// Minimal only strips high-entropy secrets and explicit API keys
96+
result := FilterWithLevel("my key is sk-abcdefghijklmnopqrstuvwxyz", Minimal)
97+
if !strings.Contains(result, "[REDACTED]") {
98+
t.Errorf("Minimal should redact API keys, got: %s", result)
99+
}
100+
// Minimal does NOT strip SSN or phone
101+
result = FilterWithLevel("ssn is 123-45-6789", Minimal)
102+
if strings.Contains(result, "[REDACTED]") {
103+
t.Errorf("Minimal should not redact SSN, got: %s", result)
104+
}
105+
}
106+
66107
func TestFilter_ConnectionStrings(t *testing.T) {
67108
input := "db url: postgres://admin:s3cret@db.example.com:5432/mydb"
68109
result := Filter(input)

0 commit comments

Comments
 (0)