Skip to content

Commit 0e12707

Browse files
committed
refactor: comprehensive hardening — txStore, BFS directionality, AccessTracker, rate limiting, input validation, error handling
This commit addresses all issues identified in the deep codebase review: - txStore: implement all 30+ stub methods with real transaction-backed SQL - Engine: add sync.Mutex to serialize writes, prevent SQLite busy-lock races - BFS: respect edge directionality (acyclic=directed, cyclic=bidirectional) - SQLite: add _busy_timeout, MaxOpenConns(1), batch chunking at 900 vars - AccessTracker: replace per-recall UPDATE with INSERT + batched flush - GetEdgesBetween: eliminate N+1 edge queries in ExtractSubgraph - 6 new DB indexes for faster filtering and traversal - Domain errors: ErrNodeNotFound, ErrEdgeNotFound, ErrDuplicateNode, etc. - REST/MCP: input validation for node types, edge types, depth caps, body limits - Rate limiter: per-IP token bucket (30 burst / 10 sustained) - escapeFTS5: prevent query injection into FTS5 MATCH - Privacy filter: entropy detection for secret leakage - gRPC: add GracefulStop shutdown method - REST: add Shutdown with context - DualStream: detached 1-minute context for slow-path survival - SSE: panic recovery around json.Marshal - PendingNodes: cap at 1000 nodes, add MinConfidence filter - All silent error ignores replaced with proper handling in production code - Comprehensive tests: graph_test (7), rest_test (8), sqlite_test (18), engine_test (15), integration_test (concurrent access)
1 parent 0dfb895 commit 0e12707

28 files changed

Lines changed: 2227 additions & 223 deletions

cmd/yaad/core.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/GrayCodeAI/yaad/internal/engine"
1212
"github.com/GrayCodeAI/yaad/internal/storage"
1313
"github.com/GrayCodeAI/yaad/internal/utils"
14+
"github.com/GrayCodeAI/yaad/internal/version"
1415
)
1516

1617
var initCmd = &cobra.Command{
@@ -178,7 +179,7 @@ var statusCmd = &cobra.Command{
178179
fmt.Fprintf(os.Stderr, "error: %v\n", err)
179180
os.Exit(1)
180181
}
181-
fmt.Printf("yaad v%s\n", version)
182+
fmt.Printf("yaad v%s\n", version.String())
182183
fmt.Printf(" Nodes: %d\n", st.Nodes)
183184
fmt.Printf(" Edges: %d\n", st.Edges)
184185
fmt.Printf(" Sessions: %d\n", st.Sessions)

cmd/yaad/helpers.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@ func dbPath() string {
2121
// openEngine opens the yaad database and returns an engine.
2222
// Exits on error — CLI commands should not continue without a DB.
2323
func openEngine() *engine.Engine {
24-
if err := os.MkdirAll(filepath.Dir(dbPath()), 0755); err != nil {
25-
fmt.Fprintf(os.Stderr, "error creating .yaad/: %v\n", err)
24+
path := dbPath()
25+
if _, err := os.Stat(path); os.IsNotExist(err) {
26+
fmt.Fprintf(os.Stderr, "error: no yaad project found in %s\n", filepath.Dir(path))
27+
fmt.Fprintf(os.Stderr, "Run 'yaad init' to initialize a project.\n")
2628
os.Exit(1)
2729
}
28-
store, err := storage.NewStore(dbPath())
30+
store, err := storage.NewStore(path)
2931
if err != nil {
3032
fmt.Fprintf(os.Stderr, "error opening database: %v\n", err)
3133
os.Exit(1)

cmd/yaad/main.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@ import (
77
"github.com/spf13/cobra"
88
)
99

10-
var version = "dev" // set by -ldflags="-X main.version=v0.1.0" at build time
11-
1210
func main() {
1311
if err := rootCmd.Execute(); err != nil {
1412
fmt.Fprintln(os.Stderr, err)

cmd/yaad/server.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/GrayCodeAI/yaad/internal/server"
1212
"github.com/GrayCodeAI/yaad/internal/storage"
1313
"github.com/GrayCodeAI/yaad/internal/utils"
14+
"github.com/GrayCodeAI/yaad/internal/version"
1415
)
1516

1617
var serveCmd = &cobra.Command{
@@ -20,7 +21,7 @@ var serveCmd = &cobra.Command{
2021
eng := openEngine()
2122
defer eng.Store().Close()
2223
addr, _ := cmd.Flags().GetString("addr")
23-
fmt.Printf("yaad v%s — REST API on %s\n", version, addr)
24+
fmt.Printf("yaad v%s — REST API on %s\n", version.String(), addr)
2425
rest := server.NewRESTServer(eng, addr)
2526
if err := rest.ListenAndServe(); err != nil {
2627
fmt.Fprintf(os.Stderr, "error: %v\n", err)

integration_test.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"os"
1111
"path/filepath"
1212
"strings"
13+
"sync"
1314
"testing"
1415
"time"
1516

@@ -1157,3 +1158,63 @@ func TestRESTAPI(t *testing.T) {
11571158
}
11581159
resp.Body.Close()
11591160
}
1161+
1162+
// TestConcurrentSQLiteAccess verifies that concurrent Remember and Recall operations
1163+
// against the real SQLite backend do not race or corrupt data.
1164+
func TestConcurrentSQLiteAccess(t *testing.T) {
1165+
eng, cleanup := setup(t)
1166+
defer cleanup()
1167+
1168+
var wg sync.WaitGroup
1169+
numWriters := 5
1170+
numReaders := 5
1171+
opsPerGoroutine := 10
1172+
1173+
// Writers
1174+
for i := 0; i < numWriters; i++ {
1175+
wg.Add(1)
1176+
go func(idx int) {
1177+
defer wg.Done()
1178+
for j := 0; j < opsPerGoroutine; j++ {
1179+
_, err := eng.Remember(context.Background(), engine.RememberInput{
1180+
Type: "convention",
1181+
Content: fmt.Sprintf("writer-%d-op-%d", idx, j),
1182+
Scope: "project",
1183+
Project: "concurrent-test",
1184+
})
1185+
if err != nil {
1186+
t.Errorf("writer %d op %d failed: %v", idx, j, err)
1187+
}
1188+
}
1189+
}(i)
1190+
}
1191+
1192+
// Readers
1193+
for i := 0; i < numReaders; i++ {
1194+
wg.Add(1)
1195+
go func(idx int) {
1196+
defer wg.Done()
1197+
for j := 0; j < opsPerGoroutine; j++ {
1198+
_, err := eng.Recall(context.Background(), engine.RecallOpts{
1199+
Query: "writer",
1200+
Project: "concurrent-test",
1201+
Limit: 10,
1202+
})
1203+
if err != nil {
1204+
t.Errorf("reader %d op %d failed: %v", idx, j, err)
1205+
}
1206+
}
1207+
}(i)
1208+
}
1209+
1210+
wg.Wait()
1211+
1212+
st, err := eng.Status(context.Background(), "concurrent-test")
1213+
if err != nil {
1214+
t.Fatalf("status failed: %v", err)
1215+
}
1216+
expectedNodes := numWriters * opsPerGoroutine
1217+
if st.Nodes < expectedNodes {
1218+
t.Errorf("expected at least %d nodes, got %d", expectedNodes, st.Nodes)
1219+
}
1220+
}

internal/config/config.go

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package config
22

33
import (
4+
"fmt"
45
"os"
56
"path/filepath"
67

@@ -85,19 +86,30 @@ func Load(projectDir string) (*Config, error) {
8586
// Global config
8687
home, err := os.UserHomeDir()
8788
if err == nil {
88-
loadFile(filepath.Join(home, ".yaad", "config.toml"), cfg)
89+
if err := loadFile(filepath.Join(home, ".yaad", "config.toml"), cfg); err != nil {
90+
return nil, err
91+
}
8992
}
9093

9194
// Project config (overrides global)
9295
if projectDir != "" {
93-
loadFile(filepath.Join(projectDir, ".yaad", "config.toml"), cfg)
96+
if err := loadFile(filepath.Join(projectDir, ".yaad", "config.toml"), cfg); err != nil {
97+
return nil, err
98+
}
9499
}
95100

96101
return cfg, nil
97102
}
98103

99-
func loadFile(path string, cfg *Config) {
100-
if _, err := os.Stat(path); err == nil {
101-
toml.DecodeFile(path, cfg)
104+
func loadFile(path string, cfg *Config) error {
105+
if _, err := os.Stat(path); err != nil {
106+
if os.IsNotExist(err) {
107+
return nil // no file to load, not an error
108+
}
109+
return err
102110
}
111+
if _, err := toml.DecodeFile(path, cfg); err != nil {
112+
return fmt.Errorf("invalid config %s: %w", path, err)
113+
}
114+
return nil
103115
}

internal/engine/access_tracker.go

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
package engine
2+
3+
import (
4+
"context"
5+
"log/slog"
6+
"sync"
7+
"time"
8+
9+
"github.com/GrayCodeAI/yaad/internal/storage"
10+
)
11+
12+
const maxAccessBuffer = 10000
13+
14+
// AccessTracker batches node access events to reduce SQLite UPDATE churn.
15+
// Instead of updating nodes.access_count on every recall (which causes write
16+
// contention under concurrent load), it INSERTs lightweight rows into
17+
// access_log and periodically flushes them in a single aggregated UPDATE.
18+
type AccessTracker struct {
19+
store storage.Storage
20+
buffer []string // node IDs pending flush
21+
mu sync.Mutex
22+
flushTick *time.Ticker
23+
stopCh chan struct{}
24+
}
25+
26+
// NewAccessTracker creates a tracker that flushes every interval.
27+
func NewAccessTracker(store storage.Storage, interval time.Duration) *AccessTracker {
28+
at := &AccessTracker{
29+
store: store,
30+
buffer: make([]string, 0, 128),
31+
flushTick: time.NewTicker(interval),
32+
stopCh: make(chan struct{}),
33+
}
34+
go at.loop()
35+
return at
36+
}
37+
38+
// Log records an access for the given node ID (best-effort, non-blocking).
39+
func (at *AccessTracker) Log(ctx context.Context, nodeID string) {
40+
// Try SQLite INSERT directly (fast, append-only, minimal contention)
41+
if err := at.store.LogAccess(ctx, nodeID); err == nil {
42+
return
43+
}
44+
// Fall back to in-memory buffer if DB is temporarily unavailable
45+
at.mu.Lock()
46+
if len(at.buffer) < maxAccessBuffer {
47+
at.buffer = append(at.buffer, nodeID)
48+
} else {
49+
slog.Warn("access_tracker: buffer full, dropping access log", "node_id", nodeID)
50+
}
51+
at.mu.Unlock()
52+
}
53+
54+
// Flush immediately applies all pending access counts to nodes.
55+
func (at *AccessTracker) Flush(ctx context.Context) {
56+
// Flush any buffered in-memory items first
57+
at.mu.Lock()
58+
buf := make([]string, len(at.buffer))
59+
copy(buf, at.buffer)
60+
at.buffer = at.buffer[:0]
61+
at.mu.Unlock()
62+
63+
for _, nodeID := range buf {
64+
_ = at.store.LogAccess(ctx, nodeID)
65+
}
66+
67+
n, err := at.store.FlushAccessLog(ctx)
68+
if err != nil {
69+
slog.Warn("access_tracker: flush failed", "error", err)
70+
} else if n > 0 {
71+
slog.Debug("access_tracker: flushed", "nodes_updated", n)
72+
}
73+
}
74+
75+
// Stop halts the background flusher.
76+
func (at *AccessTracker) Stop() {
77+
at.flushTick.Stop()
78+
close(at.stopCh)
79+
}
80+
81+
func (at *AccessTracker) loop() {
82+
for {
83+
select {
84+
case <-at.flushTick.C:
85+
at.Flush(context.Background())
86+
case <-at.stopCh:
87+
return
88+
}
89+
}
90+
}

internal/engine/engine_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,26 @@ func (m *mockStorage) GetEdgesTo(ctx context.Context, nodeID string) ([]*storage
259259
return out, nil
260260
}
261261

262+
func (m *mockStorage) GetEdgesBetween(ctx context.Context, nodeIDs []string) ([]*storage.Edge, error) {
263+
if err := ctx.Err(); err != nil {
264+
return nil, err
265+
}
266+
m.mu.RLock()
267+
defer m.mu.RUnlock()
268+
idSet := make(map[string]bool, len(nodeIDs))
269+
for _, id := range nodeIDs {
270+
idSet[id] = true
271+
}
272+
var out []*storage.Edge
273+
for _, e := range m.edges {
274+
if idSet[e.FromID] && idSet[e.ToID] {
275+
cp := *e
276+
out = append(out, &cp)
277+
}
278+
}
279+
return out, nil
280+
}
281+
262282
func (m *mockStorage) CountEdges(ctx context.Context, nodeID string) (inbound int, outbound int, err error) {
263283
if err := ctx.Err(); err != nil {
264284
return 0, 0, err
@@ -443,6 +463,8 @@ func (m *mockStorage) AddReplayEvent(ctx context.Context, sessionID, data string
443463
func (m *mockStorage) GetReplayEvents(ctx context.Context, sessionID string) ([]*storage.ReplayEvent, error) {
444464
return nil, nil
445465
}
466+
func (m *mockStorage) LogAccess(ctx context.Context, nodeID string) error { return nil }
467+
func (m *mockStorage) FlushAccessLog(ctx context.Context) (int, error) { return 0, nil }
446468
func (m *mockStorage) WithTx(ctx context.Context, fn func(storage.Storage) error) error {
447469
return fn(m)
448470
}

internal/engine/feedback.go

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ func (e *Engine) Feedback(ctx context.Context, id string, action FeedbackAction,
2222
if err := ctx.Err(); err != nil {
2323
return err
2424
}
25+
e.mu.Lock()
26+
defer e.mu.Unlock()
2527
node, err := e.store.GetNode(ctx, id)
2628
if err != nil {
2729
return fmt.Errorf("node %s not found: %w", id, err)
@@ -38,7 +40,9 @@ func (e *Engine) Feedback(ctx context.Context, id string, action FeedbackAction,
3840
if newContent == "" {
3941
return fmt.Errorf("edit requires new content")
4042
}
41-
_ = e.store.SaveVersion(ctx, node.ID, node.Content, "user", "edited via feedback")
43+
if err := e.store.SaveVersion(ctx, node.ID, node.Content, "user", "edited via feedback"); err != nil {
44+
return fmt.Errorf("save version: %w", err)
45+
}
4246
node.Content = newContent
4347
node.ContentHash = contentHash(newContent, node.Scope, node.Project)
4448
node.Version++
@@ -47,7 +51,9 @@ func (e *Engine) Feedback(ctx context.Context, id string, action FeedbackAction,
4751
return e.store.UpdateNode(ctx, node)
4852

4953
case FeedbackDiscard:
50-
_ = e.store.SaveVersion(ctx, node.ID, node.Content, "user", "discarded via feedback")
54+
if err := e.store.SaveVersion(ctx, node.ID, node.Content, "user", "discarded via feedback"); err != nil {
55+
return fmt.Errorf("save version: %w", err)
56+
}
5157
node.Confidence = 0
5258
return e.store.UpdateNode(ctx, node)
5359

@@ -61,6 +67,8 @@ func (e *Engine) Rollback(ctx context.Context, id string, version int) error {
6167
if err := ctx.Err(); err != nil {
6268
return err
6369
}
70+
e.mu.Lock()
71+
defer e.mu.Unlock()
6472
versions, err := e.store.GetVersions(ctx, id)
6573
if err != nil {
6674
return err
@@ -71,7 +79,9 @@ func (e *Engine) Rollback(ctx context.Context, id string, version int) error {
7179
if err != nil {
7280
return err
7381
}
74-
_ = e.store.SaveVersion(ctx, node.ID, node.Content, "system", fmt.Sprintf("rollback to v%d", version))
82+
if err := e.store.SaveVersion(ctx, node.ID, node.Content, "system", fmt.Sprintf("rollback to v%d", version)); err != nil {
83+
return fmt.Errorf("save version: %w", err)
84+
}
7585
node.Content = v.Content
7686
node.Version++
7787
node.UpdatedAt = time.Now()
@@ -82,18 +92,22 @@ func (e *Engine) Rollback(ctx context.Context, id string, version int) error {
8292
}
8393

8494
// PendingNodes returns low-confidence nodes that may need review.
95+
// Limited to 1000 nodes to prevent unbounded memory use on large graphs.
8596
func (e *Engine) PendingNodes(ctx context.Context, project string, threshold float64) ([]*storage.Node, error) {
8697
if err := ctx.Err(); err != nil {
8798
return nil, err
8899
}
89-
nodes, err := e.store.ListNodes(ctx, storage.NodeFilter{Project: project})
100+
nodes, err := e.store.ListNodes(ctx, storage.NodeFilter{Project: project, MinConfidence: 0.01})
90101
if err != nil {
91102
return nil, err
92103
}
93104
var pending []*storage.Node
94105
for _, n := range nodes {
95106
if n.Confidence > 0 && n.Confidence < threshold {
96107
pending = append(pending, n)
108+
if len(pending) >= 1000 {
109+
break
110+
}
97111
}
98112
}
99113
return pending, nil

0 commit comments

Comments
 (0)