diff --git a/engine/engine_mock_graph_test.go b/engine/engine_mock_graph_test.go new file mode 100644 index 0000000..4442c85 --- /dev/null +++ b/engine/engine_mock_graph_test.go @@ -0,0 +1,141 @@ +//nolint:gocritic +package engine + +import ( + "context" + + "github.com/GrayCodeAI/yaad/graph" + "github.com/GrayCodeAI/yaad/intent" + "github.com/GrayCodeAI/yaad/storage" +) + +// This file is part of package engine tests. It holds mockGraph (the +// in-memory graph.Graph test double) and newTestEngine, moved verbatim out +// of engine_test.go for readability; behavior is unchanged. + +// --------------------------------------------------------------------------- +// mockGraph — in-memory implementation of graph.Graph backed by storage +// --------------------------------------------------------------------------- + +type mockGraph struct { + store storage.Storage +} + +func newMockGraph(store storage.Storage) *mockGraph { + return &mockGraph{store: store} +} + +func (g *mockGraph) AddNode(ctx context.Context, n *storage.Node) error { + return g.store.CreateNode(ctx, n) +} + +func (g *mockGraph) AddEdge(ctx context.Context, e *storage.Edge) error { + return g.store.CreateEdge(ctx, e) +} + +func (g *mockGraph) RemoveNode(ctx context.Context, id string) error { + return g.store.DeleteNode(ctx, id) +} + +func (g *mockGraph) RemoveEdge(ctx context.Context, id string) error { + return g.store.DeleteEdge(ctx, id) +} + +func (g *mockGraph) ExtractSubgraph(ctx context.Context, startID string, maxDepth int) (*graph.Subgraph, error) { + ids, err := g.BFS(ctx, startID, maxDepth) + if err != nil { + return nil, err + } + sg := &graph.Subgraph{} + for _, id := range ids { + n, err := g.store.GetNode(ctx, id) + if err == nil { + sg.Nodes = append(sg.Nodes, n) + } + } + idSet := make(map[string]bool, len(ids)) + for _, id := range ids { + idSet[id] = true + } + for _, id := range ids { + edges, _ := g.store.GetEdgesFrom(ctx, id) + for _, e := range edges { + if idSet[e.ToID] { + sg.Edges = append(sg.Edges, e) + } + } + } + return sg, nil +} + +func (g *mockGraph) BFS(ctx context.Context, startID string, maxDepth int) ([]string, error) { + _, err := g.store.GetNode(ctx, startID) + if err != nil { + return nil, nil + } + visited := map[string]bool{startID: true} + queue := []struct { + id string + depth int + }{{startID, 0}} + var result []string + result = append(result, startID) + + for len(queue) > 0 { + curr := queue[0] + queue = queue[1:] + if curr.depth >= maxDepth { + continue + } + edges, _ := g.store.GetEdgesFrom(ctx, curr.id) + edgesTo, _ := g.store.GetEdgesTo(ctx, curr.id) + allEdges := append(edges, edgesTo...) + for _, e := range allEdges { + var next string + if e.FromID == curr.id { + next = e.ToID + } else { + next = e.FromID + } + if !visited[next] { + visited[next] = true + result = append(result, next) + queue = append(queue, struct { + id string + depth int + }{next, curr.depth + 1}) + } + } + } + return result, nil +} + +func (g *mockGraph) IntentBFS(ctx context.Context, startID string, maxDepth int, queryIntent intent.Intent) ([]string, error) { + // For mock, delegate to plain BFS (intent weights are ignored) + return g.BFS(ctx, startID, maxDepth) +} + +func (g *mockGraph) Impact(ctx context.Context, filePath string, maxDepth int) ([]string, error) { + return nil, nil +} + +func (g *mockGraph) Ancestors(ctx context.Context, id string) ([]string, error) { + return nil, nil +} + +func (g *mockGraph) Descendants(ctx context.Context, id string) ([]string, error) { + return nil, nil +} + +// --------------------------------------------------------------------------- +// helper +// --------------------------------------------------------------------------- + +func newTestEngine() *Engine { + ms := newMockStorage() + return New(ms, newMockGraph(ms)) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- diff --git a/engine/engine_mock_storage_test.go b/engine/engine_mock_storage_test.go new file mode 100644 index 0000000..57ed212 --- /dev/null +++ b/engine/engine_mock_storage_test.go @@ -0,0 +1,673 @@ +//nolint:gocritic +package engine + +import ( + "context" + "strings" + "sync" + "time" + + "github.com/GrayCodeAI/yaad/storage" +) + +// This file is part of package engine tests. It holds mockStorage, the +// in-memory storage.Storage test double, moved verbatim out of +// engine_test.go for readability; behavior is unchanged. + +// --------------------------------------------------------------------------- +// mockStorage — in-memory implementation of storage.Storage +// --------------------------------------------------------------------------- + +type mockStorage struct { + mu sync.RWMutex + nodes map[string]*storage.Node + edges map[string]*storage.Edge + sessions map[string]*storage.Session + versions map[string][]*storage.NodeVersion + embeds map[string][]float32 + watches []fileWatch + metadata map[string]map[string]string +} + +type fileWatch struct { + filePath, nodeID, gitHash string +} + +func newMockStorage() *mockStorage { + return &mockStorage{ + nodes: make(map[string]*storage.Node), + edges: make(map[string]*storage.Edge), + sessions: make(map[string]*storage.Session), + versions: make(map[string][]*storage.NodeVersion), + embeds: make(map[string][]float32), + metadata: make(map[string]map[string]string), + } +} + +func (m *mockStorage) CreateNode(ctx context.Context, n *storage.Node) error { + if err := ctx.Err(); err != nil { + return err + } + m.mu.Lock() + defer m.mu.Unlock() + cp := *n + m.nodes[n.ID] = &cp + return nil +} + +func (m *mockStorage) GetNode(ctx context.Context, id string) (*storage.Node, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + m.mu.RLock() + defer m.mu.RUnlock() + if n, ok := m.nodes[id]; ok { + cp := *n + return &cp, nil + } + return nil, storage.ErrNodeNotFound +} + +func (m *mockStorage) GetNodeByKey(ctx context.Context, key, project string) (*storage.Node, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + m.mu.RLock() + defer m.mu.RUnlock() + for _, n := range m.nodes { + if n.Key == key && n.Project == project { + cp := *n + return &cp, nil + } + } + return nil, nil +} + +func (m *mockStorage) GetNodesBatch(ctx context.Context, ids []string) ([]*storage.Node, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + m.mu.RLock() + defer m.mu.RUnlock() + var out []*storage.Node + for _, id := range ids { + if n, ok := m.nodes[id]; ok { + cp := *n + out = append(out, &cp) + } + } + return out, nil +} + +func (m *mockStorage) UpdateNode(ctx context.Context, n *storage.Node) error { + if err := ctx.Err(); err != nil { + return err + } + m.mu.Lock() + defer m.mu.Unlock() + cp := *n + m.nodes[n.ID] = &cp + return nil +} + +func (m *mockStorage) UpdateNodeContent(ctx context.Context, id, newContent string) error { + if err := ctx.Err(); err != nil { + return err + } + m.mu.Lock() + defer m.mu.Unlock() + if n, ok := m.nodes[id]; ok { + n.Content = newContent + return nil + } + return storage.ErrNodeNotFound +} + +func (m *mockStorage) DeleteNode(ctx context.Context, id string) error { + if err := ctx.Err(); err != nil { + return err + } + m.mu.Lock() + defer m.mu.Unlock() + delete(m.nodes, id) + return nil +} + +func (m *mockStorage) ListNodes(ctx context.Context, f storage.NodeFilter) ([]*storage.Node, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + m.mu.RLock() + defer m.mu.RUnlock() + var out []*storage.Node + for _, n := range m.nodes { + if f.Type != "" && n.Type != f.Type { + continue + } + if f.Scope != "" && n.Scope != f.Scope { + continue + } + if f.Project != "" && n.Project != f.Project { + continue + } + if f.Tier > 0 && n.Tier != f.Tier { + continue + } + if f.MinConfidence > 0 && n.Confidence < f.MinConfidence { + continue + } + // return a copy to avoid races when caller mutates + cp := *n + out = append(out, &cp) + } + return out, nil +} + +func (m *mockStorage) SearchNodes(ctx context.Context, query string, limit int) ([]*storage.Node, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + m.mu.RLock() + defer m.mu.RUnlock() + var out []*storage.Node + words := strings.Fields(strings.ToLower(query)) + for _, n := range m.nodes { + if query == "" || matchesAnyWord(words, n.Content, n.Summary, n.Tags) { + cp := *n + out = append(out, &cp) + if limit > 0 && len(out) >= limit { + break + } + } + } + return out, nil +} + +func matchesAnyWord(words []string, fields ...string) bool { + for _, f := range fields { + lower := strings.ToLower(f) + for _, w := range words { + if strings.Contains(lower, w) { + return true + } + } + } + return false +} + +func (m *mockStorage) SearchNodeByHash(ctx context.Context, hash, scope, project string) (*storage.Node, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + m.mu.RLock() + defer m.mu.RUnlock() + for _, n := range m.nodes { + if n.ContentHash == hash && n.Scope == scope && n.Project == project { + cp := *n + return &cp, nil + } + } + return nil, nil +} + +func (m *mockStorage) GetNeighbors(ctx context.Context, nodeID string) ([]*storage.Node, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + m.mu.RLock() + defer m.mu.RUnlock() + seen := map[string]bool{} + var out []*storage.Node + for _, e := range m.edges { + var other string + if e.FromID == nodeID { + other = e.ToID + } else if e.ToID == nodeID { + other = e.FromID + } else { + continue + } + if seen[other] { + continue + } + seen[other] = true + if n, ok := m.nodes[other]; ok { + cp := *n + out = append(out, &cp) + } + } + return out, nil +} + +func (m *mockStorage) CreateEdge(ctx context.Context, e *storage.Edge) error { + if err := ctx.Err(); err != nil { + return err + } + m.mu.Lock() + defer m.mu.Unlock() + m.edges[e.ID] = e + return nil +} + +func (m *mockStorage) GetEdge(ctx context.Context, id string) (*storage.Edge, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + m.mu.RLock() + defer m.mu.RUnlock() + if e, ok := m.edges[id]; ok { + cp := *e + return &cp, nil + } + return nil, storage.ErrEdgeNotFound +} + +func (m *mockStorage) InvalidateEdge(ctx context.Context, id string) error { + return nil +} + +func (m *mockStorage) DeleteEdge(ctx context.Context, id string) error { + if err := ctx.Err(); err != nil { + return err + } + m.mu.Lock() + defer m.mu.Unlock() + delete(m.edges, id) + return nil +} + +func (m *mockStorage) GetEdgesFrom(ctx context.Context, nodeID string) ([]*storage.Edge, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + m.mu.RLock() + defer m.mu.RUnlock() + var out []*storage.Edge + for _, e := range m.edges { + if e.FromID == nodeID { + cp := *e + out = append(out, &cp) + } + } + return out, nil +} + +func (m *mockStorage) GetEdgesTo(ctx context.Context, nodeID string) ([]*storage.Edge, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + m.mu.RLock() + defer m.mu.RUnlock() + var out []*storage.Edge + for _, e := range m.edges { + if e.ToID == nodeID { + cp := *e + out = append(out, &cp) + } + } + return out, nil +} + +func (m *mockStorage) GetValidEdgesFrom(ctx context.Context, nodeID string, at time.Time) ([]*storage.Edge, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + if at.IsZero() { + at = time.Now().UTC() + } + m.mu.RLock() + defer m.mu.RUnlock() + var out []*storage.Edge + for _, e := range m.edges { + if e.FromID == nodeID && validEdgeAt(e, at) { + cp := *e + out = append(out, &cp) + } + } + return out, nil +} + +func (m *mockStorage) GetValidEdgesTo(ctx context.Context, nodeID string, at time.Time) ([]*storage.Edge, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + if at.IsZero() { + at = time.Now().UTC() + } + m.mu.RLock() + defer m.mu.RUnlock() + var out []*storage.Edge + for _, e := range m.edges { + if e.ToID == nodeID && validEdgeAt(e, at) { + cp := *e + out = append(out, &cp) + } + } + return out, nil +} + +func validEdgeAt(e *storage.Edge, at time.Time) bool { + if !e.ValidAt.IsZero() && e.ValidAt.After(at) { + return false + } + if !e.InvalidAt.IsZero() && !e.InvalidAt.After(at) { + return false + } + return true +} + +func (m *mockStorage) GetEdgesBetween(ctx context.Context, nodeIDs []string) ([]*storage.Edge, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + m.mu.RLock() + defer m.mu.RUnlock() + idSet := make(map[string]bool, len(nodeIDs)) + for _, id := range nodeIDs { + idSet[id] = true + } + var out []*storage.Edge + for _, e := range m.edges { + if idSet[e.FromID] && idSet[e.ToID] { + cp := *e + out = append(out, &cp) + } + } + return out, nil +} + +func (m *mockStorage) GetAllEdgesFor(ctx context.Context, nodeIDs []string) ([]*storage.Edge, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + m.mu.RLock() + defer m.mu.RUnlock() + idSet := make(map[string]bool, len(nodeIDs)) + for _, id := range nodeIDs { + idSet[id] = true + } + var out []*storage.Edge + for _, e := range m.edges { + if idSet[e.FromID] || idSet[e.ToID] { + cp := *e + out = append(out, &cp) + } + } + return out, nil +} + +func (m *mockStorage) CountEdges(ctx context.Context, nodeID string) (inbound int, outbound int, err error) { + if err := ctx.Err(); err != nil { + return 0, 0, err + } + m.mu.RLock() + defer m.mu.RUnlock() + for _, e := range m.edges { + if e.FromID == nodeID { + outbound++ + } + if e.ToID == nodeID { + inbound++ + } + } + return inbound, outbound, nil +} + +func (m *mockStorage) CountAllEdges(ctx context.Context) (int, error) { + if err := ctx.Err(); err != nil { + return 0, err + } + m.mu.RLock() + defer m.mu.RUnlock() + return len(m.edges), nil +} + +func (m *mockStorage) CountEdgesBatch(ctx context.Context, nodeIDs []string) (map[string][2]int, error) { + result := make(map[string][2]int, len(nodeIDs)) + for _, id := range nodeIDs { + inbound, outbound, _ := m.CountEdges(ctx, id) + result[id] = [2]int{inbound, outbound} + } + return result, nil +} + +func (m *mockStorage) CheckCycle(ctx context.Context, fromID, toID string) (bool, error) { + if err := ctx.Err(); err != nil { + return false, err + } + m.mu.RLock() + defer m.mu.RUnlock() + seen := map[string]bool{} + var walk func(id string) bool + walk = func(id string) bool { + if id == toID { + return true + } + if seen[id] { + return false + } + seen[id] = true + for _, e := range m.edges { + acyclic := e.Type == "caused_by" || e.Type == "led_to" || e.Type == "supersedes" || + e.Type == "learned_in" || e.Type == "part_of" + if e.ToID == id && acyclic { + if walk(e.FromID) { + return true + } + } + } + return false + } + return walk(fromID), nil +} + +func (m *mockStorage) CreateSession(ctx context.Context, sess *storage.Session) error { + if err := ctx.Err(); err != nil { + return err + } + m.mu.Lock() + defer m.mu.Unlock() + m.sessions[sess.ID] = sess + return nil +} + +func (m *mockStorage) EndSession(ctx context.Context, id string, summary string) error { + if err := ctx.Err(); err != nil { + return err + } + m.mu.Lock() + defer m.mu.Unlock() + if s, ok := m.sessions[id]; ok { + s.Summary = summary + s.EndedAt = time.Now() + } + return nil +} + +func (m *mockStorage) ListSessions(ctx context.Context, project string, limit int) ([]*storage.Session, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + m.mu.RLock() + defer m.mu.RUnlock() + var out []*storage.Session + for _, s := range m.sessions { + if project == "" || s.Project == project { + out = append(out, s) + } + } + return out, nil +} + +func (m *mockStorage) SaveVersion(ctx context.Context, nodeID string, content, changedBy, reason string) error { + if err := ctx.Err(); err != nil { + return err + } + m.mu.Lock() + defer m.mu.Unlock() + vers := m.versions[nodeID] + nextVer := 1 + for _, v := range vers { + if v.Version >= nextVer { + nextVer = v.Version + 1 + } + } + m.versions[nodeID] = append(vers, &storage.NodeVersion{ + NodeID: nodeID, + Content: content, + ChangedBy: changedBy, + Reason: reason, + Version: nextVer, + ChangedAt: time.Now(), + }) + return nil +} + +func (m *mockStorage) GetVersions(ctx context.Context, nodeID string) ([]*storage.NodeVersion, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + m.mu.RLock() + defer m.mu.RUnlock() + out := make([]*storage.NodeVersion, len(m.versions[nodeID])) + copy(out, m.versions[nodeID]) + return out, nil +} + +func (m *mockStorage) SaveEmbedding(ctx context.Context, nodeID, model string, vector []float32) error { + if err := ctx.Err(); err != nil { + return err + } + m.mu.Lock() + defer m.mu.Unlock() + cp := make([]float32, len(vector)) + copy(cp, vector) + m.embeds[nodeID] = cp + return nil +} + +func (m *mockStorage) DeleteEmbedding(ctx context.Context, nodeID string) error { + if err := ctx.Err(); err != nil { + return err + } + m.mu.Lock() + defer m.mu.Unlock() + delete(m.embeds, nodeID) + return nil +} + +func (m *mockStorage) AllEmbeddings(ctx context.Context, _ string) (map[string][]float32, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + m.mu.RLock() + defer m.mu.RUnlock() + out := make(map[string][]float32, len(m.embeds)) + for k, v := range m.embeds { + cp := make([]float32, len(v)) + copy(cp, v) + out[k] = cp + } + return out, nil +} + +func (m *mockStorage) GetEmbedding(ctx context.Context, nodeID string) ([]float32, string, error) { + m.mu.RLock() + defer m.mu.RUnlock() + if v, ok := m.embeds[nodeID]; ok { + cp := make([]float32, len(v)) + copy(cp, v) + return cp, "", nil + } + return nil, "", nil +} + +func (m *mockStorage) GetEmbeddingsBatch(ctx context.Context, model string, offset, limit int) (map[string][]float32, error) { + return m.AllEmbeddings(ctx, model) +} + +func (m *mockStorage) AddFileWatch(ctx context.Context, filePath, nodeID, gitHash string) error { + if err := ctx.Err(); err != nil { + return err + } + m.mu.Lock() + defer m.mu.Unlock() + m.watches = append(m.watches, fileWatch{filePath, nodeID, gitHash}) + return nil +} + +func (m *mockStorage) AddReplayEvent(ctx context.Context, sessionID, data string) error { return nil } + +func (m *mockStorage) GetReplayEvents(ctx context.Context, sessionID string) ([]*storage.ReplayEvent, error) { + return nil, nil +} +func (m *mockStorage) LogAccess(ctx context.Context, nodeID string) error { return nil } +func (m *mockStorage) FlushAccessLog(ctx context.Context) (int, error) { return 0, nil } + +func (m *mockStorage) SaveNodeMetadata(_ context.Context, nodeID string, meta map[string]string) error { + m.mu.Lock() + defer m.mu.Unlock() + m.metadata[nodeID] = meta + return nil +} + +func (m *mockStorage) LoadNodeMetadata(_ context.Context, nodeIDs []string) (map[string]map[string]string, error) { + m.mu.RLock() + defer m.mu.RUnlock() + out := make(map[string]map[string]string, len(nodeIDs)) + for _, id := range nodeIDs { + if meta, ok := m.metadata[id]; ok { + out[id] = meta + } + } + return out, nil +} + +func (m *mockStorage) NodeStats(ctx context.Context) (map[string]int, int, error) { + m.mu.RLock() + defer m.mu.RUnlock() + stats := make(map[string]int) + total := 0 + for _, n := range m.nodes { + stats[n.Type]++ + total++ + } + return stats, total, nil +} + +func (m *mockStorage) DoctorStats(ctx context.Context) (storage.DoctorStatsResult, error) { + m.mu.RLock() + defer m.mu.RUnlock() + var r storage.DoctorStatsResult + for _, n := range m.nodes { + r.TotalNodes++ + if n.Confidence < 0.2 { + r.LowConfidence++ + } + if n.Pinned { + r.Pinned++ + } + } + return r, nil +} + +func (m *mockStorage) LastUpdated(ctx context.Context) (time.Time, error) { + return time.Time{}, nil +} + +func (m *mockStorage) TopConnected(ctx context.Context, limit int) ([]string, error) { + return nil, nil +} + +func (m *mockStorage) SaveSignature(ctx context.Context, nodeID, signature string) error { + return nil +} + +func (m *mockStorage) GetAllSignatures(ctx context.Context) (map[string]string, error) { + return nil, nil +} + +func (m *mockStorage) WithTx(ctx context.Context, fn func(storage.Storage) error) error { + return fn(m) +} +func (m *mockStorage) Close() error { return nil } diff --git a/engine/engine_test.go b/engine/engine_test.go index 986ea27..85ba503 100644 --- a/engine/engine_test.go +++ b/engine/engine_test.go @@ -3,802 +3,17 @@ package engine import ( "context" - "strings" "sync" "testing" - "time" "github.com/GrayCodeAI/yaad/graph" - "github.com/GrayCodeAI/yaad/intent" "github.com/GrayCodeAI/yaad/storage" ) -// --------------------------------------------------------------------------- -// mockStorage — in-memory implementation of storage.Storage -// --------------------------------------------------------------------------- +// Note: the mockStorage and mockGraph test doubles were moved verbatim into +// engine_mock_storage_test.go and engine_mock_graph_test.go for +// readability. This file keeps the engine behavior tests. -type mockStorage struct { - mu sync.RWMutex - nodes map[string]*storage.Node - edges map[string]*storage.Edge - sessions map[string]*storage.Session - versions map[string][]*storage.NodeVersion - embeds map[string][]float32 - watches []fileWatch - metadata map[string]map[string]string -} - -type fileWatch struct { - filePath, nodeID, gitHash string -} - -func newMockStorage() *mockStorage { - return &mockStorage{ - nodes: make(map[string]*storage.Node), - edges: make(map[string]*storage.Edge), - sessions: make(map[string]*storage.Session), - versions: make(map[string][]*storage.NodeVersion), - embeds: make(map[string][]float32), - metadata: make(map[string]map[string]string), - } -} - -func (m *mockStorage) CreateNode(ctx context.Context, n *storage.Node) error { - if err := ctx.Err(); err != nil { - return err - } - m.mu.Lock() - defer m.mu.Unlock() - cp := *n - m.nodes[n.ID] = &cp - return nil -} - -func (m *mockStorage) GetNode(ctx context.Context, id string) (*storage.Node, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - m.mu.RLock() - defer m.mu.RUnlock() - if n, ok := m.nodes[id]; ok { - cp := *n - return &cp, nil - } - return nil, storage.ErrNodeNotFound -} - -func (m *mockStorage) GetNodeByKey(ctx context.Context, key, project string) (*storage.Node, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - m.mu.RLock() - defer m.mu.RUnlock() - for _, n := range m.nodes { - if n.Key == key && n.Project == project { - cp := *n - return &cp, nil - } - } - return nil, nil -} - -func (m *mockStorage) GetNodesBatch(ctx context.Context, ids []string) ([]*storage.Node, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - m.mu.RLock() - defer m.mu.RUnlock() - var out []*storage.Node - for _, id := range ids { - if n, ok := m.nodes[id]; ok { - cp := *n - out = append(out, &cp) - } - } - return out, nil -} - -func (m *mockStorage) UpdateNode(ctx context.Context, n *storage.Node) error { - if err := ctx.Err(); err != nil { - return err - } - m.mu.Lock() - defer m.mu.Unlock() - cp := *n - m.nodes[n.ID] = &cp - return nil -} - -func (m *mockStorage) UpdateNodeContent(ctx context.Context, id, newContent string) error { - if err := ctx.Err(); err != nil { - return err - } - m.mu.Lock() - defer m.mu.Unlock() - if n, ok := m.nodes[id]; ok { - n.Content = newContent - return nil - } - return storage.ErrNodeNotFound -} - -func (m *mockStorage) DeleteNode(ctx context.Context, id string) error { - if err := ctx.Err(); err != nil { - return err - } - m.mu.Lock() - defer m.mu.Unlock() - delete(m.nodes, id) - return nil -} - -func (m *mockStorage) ListNodes(ctx context.Context, f storage.NodeFilter) ([]*storage.Node, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - m.mu.RLock() - defer m.mu.RUnlock() - var out []*storage.Node - for _, n := range m.nodes { - if f.Type != "" && n.Type != f.Type { - continue - } - if f.Scope != "" && n.Scope != f.Scope { - continue - } - if f.Project != "" && n.Project != f.Project { - continue - } - if f.Tier > 0 && n.Tier != f.Tier { - continue - } - if f.MinConfidence > 0 && n.Confidence < f.MinConfidence { - continue - } - // return a copy to avoid races when caller mutates - cp := *n - out = append(out, &cp) - } - return out, nil -} - -func (m *mockStorage) SearchNodes(ctx context.Context, query string, limit int) ([]*storage.Node, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - m.mu.RLock() - defer m.mu.RUnlock() - var out []*storage.Node - words := strings.Fields(strings.ToLower(query)) - for _, n := range m.nodes { - if query == "" || matchesAnyWord(words, n.Content, n.Summary, n.Tags) { - cp := *n - out = append(out, &cp) - if limit > 0 && len(out) >= limit { - break - } - } - } - return out, nil -} - -func matchesAnyWord(words []string, fields ...string) bool { - for _, f := range fields { - lower := strings.ToLower(f) - for _, w := range words { - if strings.Contains(lower, w) { - return true - } - } - } - return false -} - -func (m *mockStorage) SearchNodeByHash(ctx context.Context, hash, scope, project string) (*storage.Node, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - m.mu.RLock() - defer m.mu.RUnlock() - for _, n := range m.nodes { - if n.ContentHash == hash && n.Scope == scope && n.Project == project { - cp := *n - return &cp, nil - } - } - return nil, nil -} - -func (m *mockStorage) GetNeighbors(ctx context.Context, nodeID string) ([]*storage.Node, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - m.mu.RLock() - defer m.mu.RUnlock() - seen := map[string]bool{} - var out []*storage.Node - for _, e := range m.edges { - var other string - if e.FromID == nodeID { - other = e.ToID - } else if e.ToID == nodeID { - other = e.FromID - } else { - continue - } - if seen[other] { - continue - } - seen[other] = true - if n, ok := m.nodes[other]; ok { - cp := *n - out = append(out, &cp) - } - } - return out, nil -} - -func (m *mockStorage) CreateEdge(ctx context.Context, e *storage.Edge) error { - if err := ctx.Err(); err != nil { - return err - } - m.mu.Lock() - defer m.mu.Unlock() - m.edges[e.ID] = e - return nil -} - -func (m *mockStorage) GetEdge(ctx context.Context, id string) (*storage.Edge, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - m.mu.RLock() - defer m.mu.RUnlock() - if e, ok := m.edges[id]; ok { - cp := *e - return &cp, nil - } - return nil, storage.ErrEdgeNotFound -} - -func (m *mockStorage) InvalidateEdge(ctx context.Context, id string) error { - return nil -} - -func (m *mockStorage) DeleteEdge(ctx context.Context, id string) error { - if err := ctx.Err(); err != nil { - return err - } - m.mu.Lock() - defer m.mu.Unlock() - delete(m.edges, id) - return nil -} - -func (m *mockStorage) GetEdgesFrom(ctx context.Context, nodeID string) ([]*storage.Edge, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - m.mu.RLock() - defer m.mu.RUnlock() - var out []*storage.Edge - for _, e := range m.edges { - if e.FromID == nodeID { - cp := *e - out = append(out, &cp) - } - } - return out, nil -} - -func (m *mockStorage) GetEdgesTo(ctx context.Context, nodeID string) ([]*storage.Edge, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - m.mu.RLock() - defer m.mu.RUnlock() - var out []*storage.Edge - for _, e := range m.edges { - if e.ToID == nodeID { - cp := *e - out = append(out, &cp) - } - } - return out, nil -} - -func (m *mockStorage) GetValidEdgesFrom(ctx context.Context, nodeID string, at time.Time) ([]*storage.Edge, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - if at.IsZero() { - at = time.Now().UTC() - } - m.mu.RLock() - defer m.mu.RUnlock() - var out []*storage.Edge - for _, e := range m.edges { - if e.FromID == nodeID && validEdgeAt(e, at) { - cp := *e - out = append(out, &cp) - } - } - return out, nil -} - -func (m *mockStorage) GetValidEdgesTo(ctx context.Context, nodeID string, at time.Time) ([]*storage.Edge, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - if at.IsZero() { - at = time.Now().UTC() - } - m.mu.RLock() - defer m.mu.RUnlock() - var out []*storage.Edge - for _, e := range m.edges { - if e.ToID == nodeID && validEdgeAt(e, at) { - cp := *e - out = append(out, &cp) - } - } - return out, nil -} - -func validEdgeAt(e *storage.Edge, at time.Time) bool { - if !e.ValidAt.IsZero() && e.ValidAt.After(at) { - return false - } - if !e.InvalidAt.IsZero() && !e.InvalidAt.After(at) { - return false - } - return true -} - -func (m *mockStorage) GetEdgesBetween(ctx context.Context, nodeIDs []string) ([]*storage.Edge, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - m.mu.RLock() - defer m.mu.RUnlock() - idSet := make(map[string]bool, len(nodeIDs)) - for _, id := range nodeIDs { - idSet[id] = true - } - var out []*storage.Edge - for _, e := range m.edges { - if idSet[e.FromID] && idSet[e.ToID] { - cp := *e - out = append(out, &cp) - } - } - return out, nil -} - -func (m *mockStorage) GetAllEdgesFor(ctx context.Context, nodeIDs []string) ([]*storage.Edge, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - m.mu.RLock() - defer m.mu.RUnlock() - idSet := make(map[string]bool, len(nodeIDs)) - for _, id := range nodeIDs { - idSet[id] = true - } - var out []*storage.Edge - for _, e := range m.edges { - if idSet[e.FromID] || idSet[e.ToID] { - cp := *e - out = append(out, &cp) - } - } - return out, nil -} - -func (m *mockStorage) CountEdges(ctx context.Context, nodeID string) (inbound int, outbound int, err error) { - if err := ctx.Err(); err != nil { - return 0, 0, err - } - m.mu.RLock() - defer m.mu.RUnlock() - for _, e := range m.edges { - if e.FromID == nodeID { - outbound++ - } - if e.ToID == nodeID { - inbound++ - } - } - return inbound, outbound, nil -} - -func (m *mockStorage) CountAllEdges(ctx context.Context) (int, error) { - if err := ctx.Err(); err != nil { - return 0, err - } - m.mu.RLock() - defer m.mu.RUnlock() - return len(m.edges), nil -} - -func (m *mockStorage) CountEdgesBatch(ctx context.Context, nodeIDs []string) (map[string][2]int, error) { - result := make(map[string][2]int, len(nodeIDs)) - for _, id := range nodeIDs { - inbound, outbound, _ := m.CountEdges(ctx, id) - result[id] = [2]int{inbound, outbound} - } - return result, nil -} - -func (m *mockStorage) CheckCycle(ctx context.Context, fromID, toID string) (bool, error) { - if err := ctx.Err(); err != nil { - return false, err - } - m.mu.RLock() - defer m.mu.RUnlock() - seen := map[string]bool{} - var walk func(id string) bool - walk = func(id string) bool { - if id == toID { - return true - } - if seen[id] { - return false - } - seen[id] = true - for _, e := range m.edges { - acyclic := e.Type == "caused_by" || e.Type == "led_to" || e.Type == "supersedes" || - e.Type == "learned_in" || e.Type == "part_of" - if e.ToID == id && acyclic { - if walk(e.FromID) { - return true - } - } - } - return false - } - return walk(fromID), nil -} - -func (m *mockStorage) CreateSession(ctx context.Context, sess *storage.Session) error { - if err := ctx.Err(); err != nil { - return err - } - m.mu.Lock() - defer m.mu.Unlock() - m.sessions[sess.ID] = sess - return nil -} - -func (m *mockStorage) EndSession(ctx context.Context, id string, summary string) error { - if err := ctx.Err(); err != nil { - return err - } - m.mu.Lock() - defer m.mu.Unlock() - if s, ok := m.sessions[id]; ok { - s.Summary = summary - s.EndedAt = time.Now() - } - return nil -} - -func (m *mockStorage) ListSessions(ctx context.Context, project string, limit int) ([]*storage.Session, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - m.mu.RLock() - defer m.mu.RUnlock() - var out []*storage.Session - for _, s := range m.sessions { - if project == "" || s.Project == project { - out = append(out, s) - } - } - return out, nil -} - -func (m *mockStorage) SaveVersion(ctx context.Context, nodeID string, content, changedBy, reason string) error { - if err := ctx.Err(); err != nil { - return err - } - m.mu.Lock() - defer m.mu.Unlock() - vers := m.versions[nodeID] - nextVer := 1 - for _, v := range vers { - if v.Version >= nextVer { - nextVer = v.Version + 1 - } - } - m.versions[nodeID] = append(vers, &storage.NodeVersion{ - NodeID: nodeID, - Content: content, - ChangedBy: changedBy, - Reason: reason, - Version: nextVer, - ChangedAt: time.Now(), - }) - return nil -} - -func (m *mockStorage) GetVersions(ctx context.Context, nodeID string) ([]*storage.NodeVersion, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - m.mu.RLock() - defer m.mu.RUnlock() - out := make([]*storage.NodeVersion, len(m.versions[nodeID])) - copy(out, m.versions[nodeID]) - return out, nil -} - -func (m *mockStorage) SaveEmbedding(ctx context.Context, nodeID, model string, vector []float32) error { - if err := ctx.Err(); err != nil { - return err - } - m.mu.Lock() - defer m.mu.Unlock() - cp := make([]float32, len(vector)) - copy(cp, vector) - m.embeds[nodeID] = cp - return nil -} - -func (m *mockStorage) DeleteEmbedding(ctx context.Context, nodeID string) error { - if err := ctx.Err(); err != nil { - return err - } - m.mu.Lock() - defer m.mu.Unlock() - delete(m.embeds, nodeID) - return nil -} - -func (m *mockStorage) AllEmbeddings(ctx context.Context, _ string) (map[string][]float32, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - m.mu.RLock() - defer m.mu.RUnlock() - out := make(map[string][]float32, len(m.embeds)) - for k, v := range m.embeds { - cp := make([]float32, len(v)) - copy(cp, v) - out[k] = cp - } - return out, nil -} - -func (m *mockStorage) GetEmbedding(ctx context.Context, nodeID string) ([]float32, string, error) { - m.mu.RLock() - defer m.mu.RUnlock() - if v, ok := m.embeds[nodeID]; ok { - cp := make([]float32, len(v)) - copy(cp, v) - return cp, "", nil - } - return nil, "", nil -} - -func (m *mockStorage) GetEmbeddingsBatch(ctx context.Context, model string, offset, limit int) (map[string][]float32, error) { - return m.AllEmbeddings(ctx, model) -} - -func (m *mockStorage) AddFileWatch(ctx context.Context, filePath, nodeID, gitHash string) error { - if err := ctx.Err(); err != nil { - return err - } - m.mu.Lock() - defer m.mu.Unlock() - m.watches = append(m.watches, fileWatch{filePath, nodeID, gitHash}) - return nil -} - -func (m *mockStorage) AddReplayEvent(ctx context.Context, sessionID, data string) error { return nil } - -func (m *mockStorage) GetReplayEvents(ctx context.Context, sessionID string) ([]*storage.ReplayEvent, error) { - return nil, nil -} -func (m *mockStorage) LogAccess(ctx context.Context, nodeID string) error { return nil } -func (m *mockStorage) FlushAccessLog(ctx context.Context) (int, error) { return 0, nil } - -func (m *mockStorage) SaveNodeMetadata(_ context.Context, nodeID string, meta map[string]string) error { - m.mu.Lock() - defer m.mu.Unlock() - m.metadata[nodeID] = meta - return nil -} - -func (m *mockStorage) LoadNodeMetadata(_ context.Context, nodeIDs []string) (map[string]map[string]string, error) { - m.mu.RLock() - defer m.mu.RUnlock() - out := make(map[string]map[string]string, len(nodeIDs)) - for _, id := range nodeIDs { - if meta, ok := m.metadata[id]; ok { - out[id] = meta - } - } - return out, nil -} - -func (m *mockStorage) NodeStats(ctx context.Context) (map[string]int, int, error) { - m.mu.RLock() - defer m.mu.RUnlock() - stats := make(map[string]int) - total := 0 - for _, n := range m.nodes { - stats[n.Type]++ - total++ - } - return stats, total, nil -} - -func (m *mockStorage) DoctorStats(ctx context.Context) (storage.DoctorStatsResult, error) { - m.mu.RLock() - defer m.mu.RUnlock() - var r storage.DoctorStatsResult - for _, n := range m.nodes { - r.TotalNodes++ - if n.Confidence < 0.2 { - r.LowConfidence++ - } - if n.Pinned { - r.Pinned++ - } - } - return r, nil -} - -func (m *mockStorage) LastUpdated(ctx context.Context) (time.Time, error) { - return time.Time{}, nil -} - -func (m *mockStorage) TopConnected(ctx context.Context, limit int) ([]string, error) { - return nil, nil -} - -func (m *mockStorage) SaveSignature(ctx context.Context, nodeID, signature string) error { - return nil -} - -func (m *mockStorage) GetAllSignatures(ctx context.Context) (map[string]string, error) { - return nil, nil -} - -func (m *mockStorage) WithTx(ctx context.Context, fn func(storage.Storage) error) error { - return fn(m) -} -func (m *mockStorage) Close() error { return nil } - -// --------------------------------------------------------------------------- -// mockGraph — in-memory implementation of graph.Graph backed by storage -// --------------------------------------------------------------------------- - -type mockGraph struct { - store storage.Storage -} - -func newMockGraph(store storage.Storage) *mockGraph { - return &mockGraph{store: store} -} - -func (g *mockGraph) AddNode(ctx context.Context, n *storage.Node) error { - return g.store.CreateNode(ctx, n) -} - -func (g *mockGraph) AddEdge(ctx context.Context, e *storage.Edge) error { - return g.store.CreateEdge(ctx, e) -} - -func (g *mockGraph) RemoveNode(ctx context.Context, id string) error { - return g.store.DeleteNode(ctx, id) -} - -func (g *mockGraph) RemoveEdge(ctx context.Context, id string) error { - return g.store.DeleteEdge(ctx, id) -} - -func (g *mockGraph) ExtractSubgraph(ctx context.Context, startID string, maxDepth int) (*graph.Subgraph, error) { - ids, err := g.BFS(ctx, startID, maxDepth) - if err != nil { - return nil, err - } - sg := &graph.Subgraph{} - for _, id := range ids { - n, err := g.store.GetNode(ctx, id) - if err == nil { - sg.Nodes = append(sg.Nodes, n) - } - } - idSet := make(map[string]bool, len(ids)) - for _, id := range ids { - idSet[id] = true - } - for _, id := range ids { - edges, _ := g.store.GetEdgesFrom(ctx, id) - for _, e := range edges { - if idSet[e.ToID] { - sg.Edges = append(sg.Edges, e) - } - } - } - return sg, nil -} - -func (g *mockGraph) BFS(ctx context.Context, startID string, maxDepth int) ([]string, error) { - _, err := g.store.GetNode(ctx, startID) - if err != nil { - return nil, nil - } - visited := map[string]bool{startID: true} - queue := []struct { - id string - depth int - }{{startID, 0}} - var result []string - result = append(result, startID) - - for len(queue) > 0 { - curr := queue[0] - queue = queue[1:] - if curr.depth >= maxDepth { - continue - } - edges, _ := g.store.GetEdgesFrom(ctx, curr.id) - edgesTo, _ := g.store.GetEdgesTo(ctx, curr.id) - allEdges := append(edges, edgesTo...) - for _, e := range allEdges { - var next string - if e.FromID == curr.id { - next = e.ToID - } else { - next = e.FromID - } - if !visited[next] { - visited[next] = true - result = append(result, next) - queue = append(queue, struct { - id string - depth int - }{next, curr.depth + 1}) - } - } - } - return result, nil -} - -func (g *mockGraph) IntentBFS(ctx context.Context, startID string, maxDepth int, queryIntent intent.Intent) ([]string, error) { - // For mock, delegate to plain BFS (intent weights are ignored) - return g.BFS(ctx, startID, maxDepth) -} - -func (g *mockGraph) Impact(ctx context.Context, filePath string, maxDepth int) ([]string, error) { - return nil, nil -} - -func (g *mockGraph) Ancestors(ctx context.Context, id string) ([]string, error) { - return nil, nil -} - -func (g *mockGraph) Descendants(ctx context.Context, id string) ([]string, error) { - return nil, nil -} - -// --------------------------------------------------------------------------- -// helper -// --------------------------------------------------------------------------- - -func newTestEngine() *Engine { - ms := newMockStorage() - return New(ms, newMockGraph(ms)) -} - -// --------------------------------------------------------------------------- -// Tests -// --------------------------------------------------------------------------- - -// TestMockStorageCompiles verifies mockStorage implements storage.Storage. func TestMockStorageCompiles(t *testing.T) { var _ storage.Storage = newMockStorage() } diff --git a/integration_api_test.go b/integration_api_test.go new file mode 100644 index 0000000..88239d1 --- /dev/null +++ b/integration_api_test.go @@ -0,0 +1,363 @@ +//go:build integration + +//nolint:noctx +package yaad_test + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "sync" + "testing" + "time" + + "github.com/GrayCodeAI/yaad/config" + "github.com/GrayCodeAI/yaad/engine" + "github.com/GrayCodeAI/yaad/internal/server" + yaadtls "github.com/GrayCodeAI/yaad/internal/tls" + "github.com/GrayCodeAI/yaad/profile" + "github.com/GrayCodeAI/yaad/storage" + "github.com/GrayCodeAI/yaad/utils" +) + +// This file is part of the yaad_test integration suite. It holds the utils, +// edge-case, profile-merge, TLS, config, storage, REST API, and concurrency +// tests moved verbatim out of integration_test.go for readability; behavior +// is unchanged. + +func TestUtilsShortID(t *testing.T) { + cases := []struct{ input, expected string }{ + {"abcdefghijklmnop", "abcdefgh"}, + {"short", "short"}, + {"12345678", "12345678"}, + {"", ""}, + {"ab", "ab"}, + } + for _, c := range cases { + got := utils.ShortID(c.input) + if got != c.expected { + t.Errorf("ShortID(%q) = %q, want %q", c.input, got, c.expected) + } + } +} + +func TestEdgeCaseEmptyRecall(t *testing.T) { + eng, cleanup := setup(t) + defer cleanup() + + // Recall on empty DB should return empty, not error + result, err := eng.Recall(context.Background(), engine.RecallOpts{Query: "nonexistent", Limit: 5}) + if err != nil { + t.Fatal(err) + } + if len(result.Nodes) != 0 { + t.Errorf("empty recall: expected 0 nodes, got %d", len(result.Nodes)) + } +} + +func TestEdgeCaseContextEmpty(t *testing.T) { + eng, cleanup := setup(t) + defer cleanup() + + // Context on empty DB should return empty, not error + result, err := eng.Context(context.Background(), "") + if err != nil { + t.Fatal(err) + } + if result == nil { + t.Error("context: should return empty result, not nil") + } +} + +func TestEdgeCaseForgetNonexistent(t *testing.T) { + eng, cleanup := setup(t) + defer cleanup() + + // Forget nonexistent node should error gracefully + err := eng.Forget(context.Background(), "nonexistent-id-12345678") + if err == nil { + t.Error("forget: should error on nonexistent node") + } +} + +func TestProfileMerge(t *testing.T) { + a := &profile.Profile{ + Project: "test", + Static: []string{"Use jose", "Use NATS"}, + Dynamic: []string{"[task] rate limiting"}, + Stack: []string{"TypeScript", "NATS"}, + } + b := &profile.Profile{ + Static: []string{"Prefer tabs", "Use jose"}, // "Use jose" is duplicate + Dynamic: []string{"[bug] auth race"}, + Stack: []string{"PostgreSQL", "NATS"}, // "NATS" is duplicate + } + merged := profile.Merge(a, b) + // Static should be deduped + if len(merged.Static) != 3 { // jose, NATS, tabs + t.Errorf("merge: expected 3 static, got %d: %v", len(merged.Static), merged.Static) + } + // Stack should be deduped + if len(merged.Stack) != 3 { // TypeScript, NATS, PostgreSQL + t.Errorf("merge: expected 3 stack, got %d: %v", len(merged.Stack), merged.Stack) + } + // Dynamic should be combined (not deduped) + if len(merged.Dynamic) != 2 { + t.Errorf("merge: expected 2 dynamic, got %d", len(merged.Dynamic)) + } +} + +func TestMultipleRememberAndRecall(t *testing.T) { + eng, cleanup := setup(t) + defer cleanup() + + // Store 20 memories of different types + types := []string{"convention", "decision", "bug", "spec", "task"} + for i := 0; i < 20; i++ { + eng.Remember(context.Background(), engine.RememberInput{ + Type: types[i%len(types)], + Content: fmt.Sprintf("Memory item %d about topic %d", i, i%5), + Scope: "project", + }) + } + + // Recall should find results + result, err := eng.Recall(context.Background(), engine.RecallOpts{Query: "topic", Limit: 10}) + if err != nil { + t.Fatal(err) + } + if len(result.Nodes) == 0 { + t.Error("bulk recall: expected nodes") + } + t.Logf("Stored 20, recalled %d nodes", len(result.Nodes)) + + // Status should show correct counts + st, _ := eng.Status(context.Background(), "") + if st.Nodes < 20 { + t.Errorf("status: expected ≥20 nodes, got %d", st.Nodes) + } +} + +func TestTLSCertGeneration(t *testing.T) { + dir := t.TempDir() + cfg := yaadtls.Config{Enabled: true} + + tlsCfg, err := yaadtls.TLSConfig(cfg, dir) + if err != nil { + t.Fatal(err) + } + if tlsCfg == nil { + t.Fatal("tls: nil config returned") + } + if len(tlsCfg.Certificates) == 0 { + t.Error("tls: no certificates generated") + } + + // Verify cert files were created + if _, err := os.Stat(filepath.Join(dir, "cert.pem")); err != nil { + t.Error("tls: cert.pem not created") + } + if _, err := os.Stat(filepath.Join(dir, "key.pem")); err != nil { + t.Error("tls: key.pem not created") + } +} + +func TestConfigDefaults(t *testing.T) { + cfg := config.Default() + if cfg.Server.Port != 3456 { + t.Errorf("config: expected port 3456, got %d", cfg.Server.Port) + } + if cfg.Decay.HalfLifeDays != 30 { + t.Errorf("config: expected half_life 30, got %d", cfg.Decay.HalfLifeDays) + } +} + +func TestStorageCreateAndQuery(t *testing.T) { + dir := t.TempDir() + store, err := storage.NewStore(filepath.Join(dir, "test.db")) + if err != nil { + t.Fatal(err) + } + defer func() { store.Close() }() + + ctx := context.Background() + + // Create node + node := &storage.Node{ + ID: "test-node-1", Type: "convention", Content: "Test content", + ContentHash: "hash1", Scope: "project", Tier: 1, Confidence: 1.0, Version: 1, + } + if err := store.CreateNode(ctx, node); err != nil { + t.Fatal(err) + } + + // Get node + got, err := store.GetNode(ctx, "test-node-1") + if err != nil { + t.Fatal(err) + } + if got.Content != "Test content" { + t.Errorf("storage: expected 'Test content', got '%s'", got.Content) + } + + // Create edge + edge := &storage.Edge{ + ID: "test-edge-1", FromID: "test-node-1", ToID: "test-node-1", + Type: "relates_to", Acyclic: false, Weight: 1.0, + } + if err := store.CreateEdge(ctx, edge); err != nil { + t.Fatal(err) + } + + // Get neighbors + neighbors, err := store.GetNeighbors(ctx, "test-node-1") + if err != nil { + t.Fatal(err) + } + if len(neighbors) == 0 { + t.Error("storage: expected neighbors") + } + + // Version history + if err := store.SaveVersion(ctx, "test-node-1", "old content", "test", "test update"); err != nil { + t.Fatal(err) + } + versions, err := store.GetVersions(ctx, "test-node-1") + if err != nil { + t.Fatal(err) + } + if len(versions) == 0 { + t.Error("storage: expected version history") + } +} + +func TestRESTAPI(t *testing.T) { + eng, cleanup := setup(t) + defer cleanup() + + mux := http.NewServeMux() + rest := server.NewRESTServer(eng, "") + rest.RegisterRoutes(mux) + ts := httptest.NewServer(mux) + defer ts.Close() + + // POST /yaad/remember + body, _ := json.Marshal(engine.RememberInput{ + Type: "convention", Content: "Always use TypeScript strict mode", Scope: "project", + }) + resp, err := http.Post(ts.URL+"/yaad/remember", "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != 201 { + t.Errorf("remember: expected 201, got %d", resp.StatusCode) + } + var node storage.Node + json.NewDecoder(resp.Body).Decode(&node) + resp.Body.Close() + if node.ID == "" { + t.Error("remember: empty node ID") + } + + // POST /yaad/recall + body, _ = json.Marshal(engine.RecallOpts{Query: "TypeScript", Limit: 5}) + resp, err = http.Post(ts.URL+"/yaad/recall", "application/json", bytes.NewReader(body)) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != 200 { + t.Errorf("recall: expected 200, got %d", resp.StatusCode) + } + var result engine.RecallResult + json.NewDecoder(resp.Body).Decode(&result) + resp.Body.Close() + if len(result.Nodes) == 0 { + t.Error("recall: expected nodes, got none") + } + + // GET /yaad/health + resp, _ = http.Get(ts.URL + "/yaad/health") + if resp.StatusCode != 200 { + t.Errorf("health: expected 200, got %d", resp.StatusCode) + } + resp.Body.Close() + + // GET /yaad/context + resp, _ = http.Get(ts.URL + "/yaad/context") + if resp.StatusCode != 200 { + t.Errorf("context: expected 200, got %d", resp.StatusCode) + } + resp.Body.Close() +} + +// TestConcurrentSQLiteAccess verifies that concurrent Remember and Recall operations +// against the real SQLite backend do not race or corrupt data. +func TestConcurrentSQLiteAccess(t *testing.T) { + eng, cleanup := setup(t) + defer cleanup() + + var wg sync.WaitGroup + numWriters := 5 + numReaders := 5 + opsPerGoroutine := 10 + + // Writers + for i := 0; i < numWriters; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + for j := 0; j < opsPerGoroutine; j++ { + _, err := eng.Remember(context.Background(), engine.RememberInput{ + Type: "convention", + Content: fmt.Sprintf("writer-%d-op-%d", idx, j), + Scope: "project", + Project: "concurrent-test", + }) + if err != nil { + // Under CI load, occasional SQLITE_BUSY is expected even with + // _busy_timeout. Skip individual operations rather than failing. + if strings.Contains(err.Error(), "database is locked") { + time.Sleep(10 * time.Millisecond) + continue + } + t.Errorf("writer %d op %d failed: %v", idx, j, err) + } + } + }(i) + } + + // Readers + for i := 0; i < numReaders; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + for j := 0; j < opsPerGoroutine; j++ { + _, err := eng.Recall(context.Background(), engine.RecallOpts{ + Query: "writer", + Project: "concurrent-test", + Limit: 10, + }) + if err != nil { + t.Errorf("reader %d op %d failed: %v", idx, j, err) + } + } + }(i) + } + + wg.Wait() + + st, err := eng.Status(context.Background(), "concurrent-test") + if err != nil { + t.Fatalf("status failed: %v", err) + } + expectedNodes := numWriters * opsPerGoroutine + if st.Nodes < expectedNodes { + t.Errorf("expected at least %d nodes, got %d", expectedNodes, st.Nodes) + } +} diff --git a/integration_features_test.go b/integration_features_test.go new file mode 100644 index 0000000..33af961 --- /dev/null +++ b/integration_features_test.go @@ -0,0 +1,358 @@ +//go:build integration + +//nolint:noctx +package yaad_test + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/GrayCodeAI/yaad/engine" + "github.com/GrayCodeAI/yaad/ingest" + intentpkg "github.com/GrayCodeAI/yaad/intent" + "github.com/GrayCodeAI/yaad/internal/bench" + "github.com/GrayCodeAI/yaad/skill" + "github.com/GrayCodeAI/yaad/storage" +) + +// This file is part of the yaad_test integration suite. It holds the +// skills, benchmark, profile, conflict, temporal, dedup, compaction, +// mental-model, intent/phase-6, and privacy tests moved verbatim out of +// integration_test.go for readability; behavior is unchanged. + +func TestPhase5Skills(t *testing.T) { + eng, cleanup := setup(t) + defer cleanup() + + sk := &skill.Skill{ + Name: "deploy", + Description: "Deploy the application", + Steps: []skill.Step{ + {Order: 1, Description: "Run tests", Command: "pnpm test"}, + {Order: 2, Description: "Build", Command: "pnpm build"}, + {Order: 3, Description: "Deploy", Command: "fly deploy"}, + }, + } + node, err := skill.Store(context.Background(), eng, sk, "") + if err != nil { + t.Fatal(err) + } + if node.ID == "" { + t.Error("skill store: empty node ID") + } + + // List skills + skills, err := skill.ListSkills(context.Background(), eng.Store(), "") + if err != nil { + t.Fatal(err) + } + if len(skills) == 0 { + t.Error("skill list: expected skills > 0") + } + + // Replay + replay := skill.Replay(sk) + if !strings.Contains(replay, "deploy") { + t.Error("skill replay missing content") + } +} + +func TestPhase5Benchmark(t *testing.T) { + eng, cleanup := setup(t) + defer cleanup() + + // Seed some memories + eng.Remember(context.Background(), engine.RememberInput{Type: "convention", Content: "Use jose not jsonwebtoken for Edge compatibility", Scope: "project"}) + eng.Remember(context.Background(), engine.RememberInput{Type: "decision", Content: "Chose NATS over Redis Streams for event bus", Scope: "project"}) + eng.Remember(context.Background(), engine.RememberInput{Type: "bug", Content: "Token refresh race condition in auth middleware", Scope: "project"}) + + result := bench.Run(context.Background(), eng, bench.DefaultQAs(), 2, 10) + if result.Total == 0 { + t.Error("benchmark: no questions evaluated") + } + // R@5 should be > 0 with seeded data + if result.HitAtK[5] == 0 { + t.Log("benchmark: R@5=0 (may be ok with small dataset)") + } + t.Logf("Benchmark:\n%s", result.String()) +} + +func TestUserProfile(t *testing.T) { + eng, cleanup := setup(t) + defer cleanup() + + eng.Remember(context.Background(), engine.RememberInput{Type: "convention", Content: "Use jose for JWT auth", Scope: "project"}) + eng.Remember(context.Background(), engine.RememberInput{Type: "decision", Content: "Chose NATS for event bus", Scope: "project"}) + eng.Remember(context.Background(), engine.RememberInput{Type: "task", Content: "Add rate limiting", Scope: "project"}) + eng.Remember(context.Background(), engine.RememberInput{Type: "preference", Content: "Prefers functional style", Scope: "project"}) + + p, err := eng.Profile(context.Background(), "") + if err != nil { + t.Fatal(err) + } + if len(p.Static) == 0 { + t.Error("profile: no static facts") + } + if p.Summary == "" { + t.Error("profile: empty summary") + } + formatted := p.Format() + if !strings.Contains(formatted, "User Profile") { + t.Error("profile: formatted output missing header") + } + t.Logf("Profile:\n%s", formatted) +} + +func TestConflictResolver(t *testing.T) { + eng, cleanup := setup(t) + defer cleanup() + + // Store original convention + old, _ := eng.Remember(context.Background(), engine.RememberInput{ + Type: "convention", Content: "Use jsonwebtoken library for JWT", Scope: "project", + }) + + // Store contradicting convention (should supersede) + newNode, _ := eng.Remember(context.Background(), engine.RememberInput{ + Type: "convention", Content: "Use jose instead of jsonwebtoken for Edge compatibility", Scope: "project", + }) + + // Verify old node confidence was lowered + oldUpdated, _ := eng.Store().GetNode(context.Background(), old.ID) + if oldUpdated.Confidence >= 1.0 { + t.Errorf("conflict: old node confidence should be lowered, got %.2f", oldUpdated.Confidence) + } + + // Verify supersedes edge exists + edges, _ := eng.Store().GetEdgesFrom(context.Background(), newNode.ID) + hasSupersedes := false + for _, e := range edges { + if e.Type == "supersedes" && e.ToID == old.ID { + hasSupersedes = true + } + } + if !hasSupersedes { + t.Error("conflict: supersedes edge not created") + } +} + +func TestTemporalBackbone(t *testing.T) { + eng, cleanup := setup(t) + defer cleanup() + + n1, _ := eng.Remember(context.Background(), engine.RememberInput{Type: "convention", Content: "First convention", Scope: "project", Project: "test"}) + n2, _ := eng.Remember(context.Background(), engine.RememberInput{Type: "decision", Content: "Second decision", Scope: "project", Project: "test"}) + n3, _ := eng.Remember(context.Background(), engine.RememberInput{Type: "bug", Content: "Third bug report", Scope: "project", Project: "test"}) + + // Verify temporal chain: n1 → n2 → n3 + edges1, _ := eng.Store().GetEdgesFrom(context.Background(), n1.ID) + hasLink12 := false + for _, e := range edges1 { + if e.Type == "learned_in" && e.ToID == n2.ID { + hasLink12 = true + } + } + edges2, _ := eng.Store().GetEdgesFrom(context.Background(), n2.ID) + hasLink23 := false + for _, e := range edges2 { + if e.Type == "learned_in" && e.ToID == n3.ID { + hasLink23 = true + } + } + if !hasLink12 { + t.Error("temporal: n1→n2 learned_in edge missing") + } + if !hasLink23 { + t.Error("temporal: n2→n3 learned_in edge missing") + } +} + +func TestDedupRollingWindow(t *testing.T) { + eng, cleanup := setup(t) + defer cleanup() + + n1, _ := eng.Remember(context.Background(), engine.RememberInput{ + Type: "convention", Content: "Use jose for JWT auth", Scope: "project", + }) + // Same content again — should return same node (dedup) + n2, _ := eng.Remember(context.Background(), engine.RememberInput{ + Type: "convention", Content: "Use jose for JWT auth", Scope: "project", + }) + + if n1.ID != n2.ID { + t.Errorf("dedup: expected same node ID, got %s and %s", n1.ID[:8], n2.ID[:8]) + } +} + +func TestCompaction(t *testing.T) { + eng, cleanup := setup(t) + defer cleanup() + + // Store 5 low-confidence nodes + for i := 0; i < 5; i++ { + n, _ := eng.Remember(context.Background(), engine.RememberInput{ + Type: "decision", Content: fmt.Sprintf("Old decision %d about something", i), Scope: "project", + }) + node, _ := eng.Store().GetNode(context.Background(), n.ID) + node.Confidence = 0.2 + node.AccessCount = 0 + eng.Store().UpdateNode(context.Background(), node) + } + + // Run compaction + compacted, err := eng.Compact(context.Background(), "") + if err != nil { + t.Fatal(err) + } + if compacted == 0 { + t.Error("compaction: expected nodes to be compacted") + } + t.Logf("Compacted %d nodes", compacted) +} + +func TestMentalModel(t *testing.T) { + eng, cleanup := setup(t) + defer cleanup() + + eng.Remember(context.Background(), engine.RememberInput{Type: "convention", Content: "Use jose for JWT", Scope: "project"}) + eng.Remember(context.Background(), engine.RememberInput{Type: "decision", Content: "Chose NATS for events", Scope: "project"}) + eng.Remember(context.Background(), engine.RememberInput{Type: "task", Content: "Add rate limiting", Scope: "project"}) + + model, err := eng.MentalModel(context.Background(), "") + if err != nil { + t.Fatal(err) + } + if model.Summary == "" { + t.Error("mental model: empty summary") + } + if len(model.Conventions) == 0 { + t.Error("mental model: no conventions") + } + formatted := model.Format() + if formatted == "" { + t.Error("mental model: empty formatted output") + } + t.Logf("Mental model:\n%s", formatted) +} + +func TestPhase6IntentClassifier(t *testing.T) { + cases := []struct { + query string + expected intentpkg.Intent + }{ + {"why did we choose NATS over Redis?", intentpkg.IntentWhy}, + {"when did we fix the auth bug?", intentpkg.IntentWhen}, + {"how to deploy the application?", intentpkg.IntentHow}, + {"what is the auth subsystem?", intentpkg.IntentWhat}, + {"which library should I use for JWT?", intentpkg.IntentWho}, + {"recall auth middleware", intentpkg.IntentGeneral}, + } + for _, c := range cases { + got := intentpkg.Classify(c.query) + if got != c.expected { + t.Errorf("Classify(%q) = %s, want %s", c.query, got, c.expected) + } + } +} + +func TestPhase6IntentAwareRetrieval(t *testing.T) { + eng, cleanup := setup(t) + defer cleanup() + + // Seed memories + decision, _ := eng.Remember(context.Background(), engine.RememberInput{Type: "decision", Content: "Chose NATS over Redis Streams for event bus", Scope: "project"}) + convention, _ := eng.Remember(context.Background(), engine.RememberInput{Type: "convention", Content: "Use NATS client v2 for all event publishing", Scope: "project"}) + + // Link: decision led_to convention + eng.Graph().AddEdge(context.Background(), &storage.Edge{ + ID: "e-test", FromID: decision.ID, ToID: convention.ID, Type: "led_to", Weight: 1.0, + }) + + // Why query should find the decision via causal traversal + result, err := eng.Recall(context.Background(), engine.RecallOpts{Query: "why NATS", Depth: 2, Limit: 10}) + if err != nil { + t.Fatal(err) + } + if len(result.Nodes) == 0 { + t.Error("intent-aware recall returned no nodes") + } + // Should find both decision and convention via causal chain + found := map[string]bool{} + for _, n := range result.Nodes { + found[n.Type] = true + } + t.Logf("Why query found types: %v", found) +} + +func TestPhase6DualStream(t *testing.T) { + eng, cleanup := setup(t) + defer cleanup() + + ds := ingest.New(eng) + defer ds.Stop() + + // Fast path should return immediately + node, err := ds.Remember(context.Background(), engine.RememberInput{ + Type: "convention", Content: "Use jose not jsonwebtoken", Scope: "project", + }) + if err != nil { + t.Fatal(err) + } + if node.ID == "" { + t.Error("dual stream: empty node ID") + } + + // Second remember should create temporal backbone edge + node2, err := ds.Remember(context.Background(), engine.RememberInput{ + Type: "decision", Content: "Chose RS256 for JWT", Scope: "project", + }) + if err != nil { + t.Fatal(err) + } + if node2.ID == "" { + t.Error("dual stream: second node empty ID") + } + + // Give slow path time to run and release DB lock + // Retry up to 500ms (slow path runs async) + var hasTemporalEdge bool + for i := 0; i < 10; i++ { + time.Sleep(50 * time.Millisecond) + edges, _ := eng.Store().GetEdgesFrom(context.Background(), node.ID) + for _, e := range edges { + if e.ToID == node2.ID && e.Type == "learned_in" { + hasTemporalEdge = true + } + } + if hasTemporalEdge { + break + } + } + if !hasTemporalEdge { + t.Error("dual stream: temporal backbone edge not created within 500ms") + } +} + +func TestPrivacyFilter(t *testing.T) { + eng, cleanup := setup(t) + defer cleanup() + + // Store content with secrets — should be stripped + node, _ := eng.Remember(context.Background(), engine.RememberInput{ + Type: "convention", + Content: "Use API key sk-1234567890abcdefghijklmnop for auth and AKIA1234567890ABCDEF for AWS", + Scope: "project", + }) + if strings.Contains(node.Content, "sk-1234567890") { + t.Error("privacy: API key not stripped") + } + if strings.Contains(node.Content, "AKIA1234567890") { + t.Error("privacy: AWS key not stripped") + } + if !strings.Contains(node.Content, "[REDACTED]") { + t.Error("privacy: expected [REDACTED] placeholder") + } +} diff --git a/integration_test.go b/integration_test.go index 9801233..d150685 100644 --- a/integration_test.go +++ b/integration_test.go @@ -4,34 +4,18 @@ package yaad_test import ( - "bytes" "context" - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" "os" "path/filepath" "strings" - "sync" "testing" - "time" - "github.com/GrayCodeAI/yaad/config" "github.com/GrayCodeAI/yaad/embeddings" "github.com/GrayCodeAI/yaad/engine" "github.com/GrayCodeAI/yaad/exportimport" "github.com/GrayCodeAI/yaad/graph" "github.com/GrayCodeAI/yaad/hooks" - "github.com/GrayCodeAI/yaad/ingest" - intentpkg "github.com/GrayCodeAI/yaad/intent" - "github.com/GrayCodeAI/yaad/internal/bench" - "github.com/GrayCodeAI/yaad/internal/server" - yaadtls "github.com/GrayCodeAI/yaad/internal/tls" - "github.com/GrayCodeAI/yaad/profile" - "github.com/GrayCodeAI/yaad/skill" "github.com/GrayCodeAI/yaad/storage" - "github.com/GrayCodeAI/yaad/utils" ) func setup(t *testing.T) (*engine.Engine, func()) { @@ -360,668 +344,3 @@ func TestPhase5ExportImport(t *testing.T) { t.Error("obsidian export: expected files > 0") } } - -func TestPhase5Skills(t *testing.T) { - eng, cleanup := setup(t) - defer cleanup() - - sk := &skill.Skill{ - Name: "deploy", - Description: "Deploy the application", - Steps: []skill.Step{ - {Order: 1, Description: "Run tests", Command: "pnpm test"}, - {Order: 2, Description: "Build", Command: "pnpm build"}, - {Order: 3, Description: "Deploy", Command: "fly deploy"}, - }, - } - node, err := skill.Store(context.Background(), eng, sk, "") - if err != nil { - t.Fatal(err) - } - if node.ID == "" { - t.Error("skill store: empty node ID") - } - - // List skills - skills, err := skill.ListSkills(context.Background(), eng.Store(), "") - if err != nil { - t.Fatal(err) - } - if len(skills) == 0 { - t.Error("skill list: expected skills > 0") - } - - // Replay - replay := skill.Replay(sk) - if !strings.Contains(replay, "deploy") { - t.Error("skill replay missing content") - } -} - -func TestPhase5Benchmark(t *testing.T) { - eng, cleanup := setup(t) - defer cleanup() - - // Seed some memories - eng.Remember(context.Background(), engine.RememberInput{Type: "convention", Content: "Use jose not jsonwebtoken for Edge compatibility", Scope: "project"}) - eng.Remember(context.Background(), engine.RememberInput{Type: "decision", Content: "Chose NATS over Redis Streams for event bus", Scope: "project"}) - eng.Remember(context.Background(), engine.RememberInput{Type: "bug", Content: "Token refresh race condition in auth middleware", Scope: "project"}) - - result := bench.Run(context.Background(), eng, bench.DefaultQAs(), 2, 10) - if result.Total == 0 { - t.Error("benchmark: no questions evaluated") - } - // R@5 should be > 0 with seeded data - if result.HitAtK[5] == 0 { - t.Log("benchmark: R@5=0 (may be ok with small dataset)") - } - t.Logf("Benchmark:\n%s", result.String()) -} - -func TestUserProfile(t *testing.T) { - eng, cleanup := setup(t) - defer cleanup() - - eng.Remember(context.Background(), engine.RememberInput{Type: "convention", Content: "Use jose for JWT auth", Scope: "project"}) - eng.Remember(context.Background(), engine.RememberInput{Type: "decision", Content: "Chose NATS for event bus", Scope: "project"}) - eng.Remember(context.Background(), engine.RememberInput{Type: "task", Content: "Add rate limiting", Scope: "project"}) - eng.Remember(context.Background(), engine.RememberInput{Type: "preference", Content: "Prefers functional style", Scope: "project"}) - - p, err := eng.Profile(context.Background(), "") - if err != nil { - t.Fatal(err) - } - if len(p.Static) == 0 { - t.Error("profile: no static facts") - } - if p.Summary == "" { - t.Error("profile: empty summary") - } - formatted := p.Format() - if !strings.Contains(formatted, "User Profile") { - t.Error("profile: formatted output missing header") - } - t.Logf("Profile:\n%s", formatted) -} - -func TestConflictResolver(t *testing.T) { - eng, cleanup := setup(t) - defer cleanup() - - // Store original convention - old, _ := eng.Remember(context.Background(), engine.RememberInput{ - Type: "convention", Content: "Use jsonwebtoken library for JWT", Scope: "project", - }) - - // Store contradicting convention (should supersede) - newNode, _ := eng.Remember(context.Background(), engine.RememberInput{ - Type: "convention", Content: "Use jose instead of jsonwebtoken for Edge compatibility", Scope: "project", - }) - - // Verify old node confidence was lowered - oldUpdated, _ := eng.Store().GetNode(context.Background(), old.ID) - if oldUpdated.Confidence >= 1.0 { - t.Errorf("conflict: old node confidence should be lowered, got %.2f", oldUpdated.Confidence) - } - - // Verify supersedes edge exists - edges, _ := eng.Store().GetEdgesFrom(context.Background(), newNode.ID) - hasSupersedes := false - for _, e := range edges { - if e.Type == "supersedes" && e.ToID == old.ID { - hasSupersedes = true - } - } - if !hasSupersedes { - t.Error("conflict: supersedes edge not created") - } -} - -func TestTemporalBackbone(t *testing.T) { - eng, cleanup := setup(t) - defer cleanup() - - n1, _ := eng.Remember(context.Background(), engine.RememberInput{Type: "convention", Content: "First convention", Scope: "project", Project: "test"}) - n2, _ := eng.Remember(context.Background(), engine.RememberInput{Type: "decision", Content: "Second decision", Scope: "project", Project: "test"}) - n3, _ := eng.Remember(context.Background(), engine.RememberInput{Type: "bug", Content: "Third bug report", Scope: "project", Project: "test"}) - - // Verify temporal chain: n1 → n2 → n3 - edges1, _ := eng.Store().GetEdgesFrom(context.Background(), n1.ID) - hasLink12 := false - for _, e := range edges1 { - if e.Type == "learned_in" && e.ToID == n2.ID { - hasLink12 = true - } - } - edges2, _ := eng.Store().GetEdgesFrom(context.Background(), n2.ID) - hasLink23 := false - for _, e := range edges2 { - if e.Type == "learned_in" && e.ToID == n3.ID { - hasLink23 = true - } - } - if !hasLink12 { - t.Error("temporal: n1→n2 learned_in edge missing") - } - if !hasLink23 { - t.Error("temporal: n2→n3 learned_in edge missing") - } -} - -func TestDedupRollingWindow(t *testing.T) { - eng, cleanup := setup(t) - defer cleanup() - - n1, _ := eng.Remember(context.Background(), engine.RememberInput{ - Type: "convention", Content: "Use jose for JWT auth", Scope: "project", - }) - // Same content again — should return same node (dedup) - n2, _ := eng.Remember(context.Background(), engine.RememberInput{ - Type: "convention", Content: "Use jose for JWT auth", Scope: "project", - }) - - if n1.ID != n2.ID { - t.Errorf("dedup: expected same node ID, got %s and %s", n1.ID[:8], n2.ID[:8]) - } -} - -func TestCompaction(t *testing.T) { - eng, cleanup := setup(t) - defer cleanup() - - // Store 5 low-confidence nodes - for i := 0; i < 5; i++ { - n, _ := eng.Remember(context.Background(), engine.RememberInput{ - Type: "decision", Content: fmt.Sprintf("Old decision %d about something", i), Scope: "project", - }) - node, _ := eng.Store().GetNode(context.Background(), n.ID) - node.Confidence = 0.2 - node.AccessCount = 0 - eng.Store().UpdateNode(context.Background(), node) - } - - // Run compaction - compacted, err := eng.Compact(context.Background(), "") - if err != nil { - t.Fatal(err) - } - if compacted == 0 { - t.Error("compaction: expected nodes to be compacted") - } - t.Logf("Compacted %d nodes", compacted) -} - -func TestMentalModel(t *testing.T) { - eng, cleanup := setup(t) - defer cleanup() - - eng.Remember(context.Background(), engine.RememberInput{Type: "convention", Content: "Use jose for JWT", Scope: "project"}) - eng.Remember(context.Background(), engine.RememberInput{Type: "decision", Content: "Chose NATS for events", Scope: "project"}) - eng.Remember(context.Background(), engine.RememberInput{Type: "task", Content: "Add rate limiting", Scope: "project"}) - - model, err := eng.MentalModel(context.Background(), "") - if err != nil { - t.Fatal(err) - } - if model.Summary == "" { - t.Error("mental model: empty summary") - } - if len(model.Conventions) == 0 { - t.Error("mental model: no conventions") - } - formatted := model.Format() - if formatted == "" { - t.Error("mental model: empty formatted output") - } - t.Logf("Mental model:\n%s", formatted) -} - -func TestPhase6IntentClassifier(t *testing.T) { - cases := []struct { - query string - expected intentpkg.Intent - }{ - {"why did we choose NATS over Redis?", intentpkg.IntentWhy}, - {"when did we fix the auth bug?", intentpkg.IntentWhen}, - {"how to deploy the application?", intentpkg.IntentHow}, - {"what is the auth subsystem?", intentpkg.IntentWhat}, - {"which library should I use for JWT?", intentpkg.IntentWho}, - {"recall auth middleware", intentpkg.IntentGeneral}, - } - for _, c := range cases { - got := intentpkg.Classify(c.query) - if got != c.expected { - t.Errorf("Classify(%q) = %s, want %s", c.query, got, c.expected) - } - } -} - -func TestPhase6IntentAwareRetrieval(t *testing.T) { - eng, cleanup := setup(t) - defer cleanup() - - // Seed memories - decision, _ := eng.Remember(context.Background(), engine.RememberInput{Type: "decision", Content: "Chose NATS over Redis Streams for event bus", Scope: "project"}) - convention, _ := eng.Remember(context.Background(), engine.RememberInput{Type: "convention", Content: "Use NATS client v2 for all event publishing", Scope: "project"}) - - // Link: decision led_to convention - eng.Graph().AddEdge(context.Background(), &storage.Edge{ - ID: "e-test", FromID: decision.ID, ToID: convention.ID, Type: "led_to", Weight: 1.0, - }) - - // Why query should find the decision via causal traversal - result, err := eng.Recall(context.Background(), engine.RecallOpts{Query: "why NATS", Depth: 2, Limit: 10}) - if err != nil { - t.Fatal(err) - } - if len(result.Nodes) == 0 { - t.Error("intent-aware recall returned no nodes") - } - // Should find both decision and convention via causal chain - found := map[string]bool{} - for _, n := range result.Nodes { - found[n.Type] = true - } - t.Logf("Why query found types: %v", found) -} - -func TestPhase6DualStream(t *testing.T) { - eng, cleanup := setup(t) - defer cleanup() - - ds := ingest.New(eng) - defer ds.Stop() - - // Fast path should return immediately - node, err := ds.Remember(context.Background(), engine.RememberInput{ - Type: "convention", Content: "Use jose not jsonwebtoken", Scope: "project", - }) - if err != nil { - t.Fatal(err) - } - if node.ID == "" { - t.Error("dual stream: empty node ID") - } - - // Second remember should create temporal backbone edge - node2, err := ds.Remember(context.Background(), engine.RememberInput{ - Type: "decision", Content: "Chose RS256 for JWT", Scope: "project", - }) - if err != nil { - t.Fatal(err) - } - if node2.ID == "" { - t.Error("dual stream: second node empty ID") - } - - // Give slow path time to run and release DB lock - // Retry up to 500ms (slow path runs async) - var hasTemporalEdge bool - for i := 0; i < 10; i++ { - time.Sleep(50 * time.Millisecond) - edges, _ := eng.Store().GetEdgesFrom(context.Background(), node.ID) - for _, e := range edges { - if e.ToID == node2.ID && e.Type == "learned_in" { - hasTemporalEdge = true - } - } - if hasTemporalEdge { - break - } - } - if !hasTemporalEdge { - t.Error("dual stream: temporal backbone edge not created within 500ms") - } -} - -func TestPrivacyFilter(t *testing.T) { - eng, cleanup := setup(t) - defer cleanup() - - // Store content with secrets — should be stripped - node, _ := eng.Remember(context.Background(), engine.RememberInput{ - Type: "convention", - Content: "Use API key sk-1234567890abcdefghijklmnop for auth and AKIA1234567890ABCDEF for AWS", - Scope: "project", - }) - if strings.Contains(node.Content, "sk-1234567890") { - t.Error("privacy: API key not stripped") - } - if strings.Contains(node.Content, "AKIA1234567890") { - t.Error("privacy: AWS key not stripped") - } - if !strings.Contains(node.Content, "[REDACTED]") { - t.Error("privacy: expected [REDACTED] placeholder") - } -} - -func TestUtilsShortID(t *testing.T) { - cases := []struct{ input, expected string }{ - {"abcdefghijklmnop", "abcdefgh"}, - {"short", "short"}, - {"12345678", "12345678"}, - {"", ""}, - {"ab", "ab"}, - } - for _, c := range cases { - got := utils.ShortID(c.input) - if got != c.expected { - t.Errorf("ShortID(%q) = %q, want %q", c.input, got, c.expected) - } - } -} - -func TestEdgeCaseEmptyRecall(t *testing.T) { - eng, cleanup := setup(t) - defer cleanup() - - // Recall on empty DB should return empty, not error - result, err := eng.Recall(context.Background(), engine.RecallOpts{Query: "nonexistent", Limit: 5}) - if err != nil { - t.Fatal(err) - } - if len(result.Nodes) != 0 { - t.Errorf("empty recall: expected 0 nodes, got %d", len(result.Nodes)) - } -} - -func TestEdgeCaseContextEmpty(t *testing.T) { - eng, cleanup := setup(t) - defer cleanup() - - // Context on empty DB should return empty, not error - result, err := eng.Context(context.Background(), "") - if err != nil { - t.Fatal(err) - } - if result == nil { - t.Error("context: should return empty result, not nil") - } -} - -func TestEdgeCaseForgetNonexistent(t *testing.T) { - eng, cleanup := setup(t) - defer cleanup() - - // Forget nonexistent node should error gracefully - err := eng.Forget(context.Background(), "nonexistent-id-12345678") - if err == nil { - t.Error("forget: should error on nonexistent node") - } -} - -func TestProfileMerge(t *testing.T) { - a := &profile.Profile{ - Project: "test", - Static: []string{"Use jose", "Use NATS"}, - Dynamic: []string{"[task] rate limiting"}, - Stack: []string{"TypeScript", "NATS"}, - } - b := &profile.Profile{ - Static: []string{"Prefer tabs", "Use jose"}, // "Use jose" is duplicate - Dynamic: []string{"[bug] auth race"}, - Stack: []string{"PostgreSQL", "NATS"}, // "NATS" is duplicate - } - merged := profile.Merge(a, b) - // Static should be deduped - if len(merged.Static) != 3 { // jose, NATS, tabs - t.Errorf("merge: expected 3 static, got %d: %v", len(merged.Static), merged.Static) - } - // Stack should be deduped - if len(merged.Stack) != 3 { // TypeScript, NATS, PostgreSQL - t.Errorf("merge: expected 3 stack, got %d: %v", len(merged.Stack), merged.Stack) - } - // Dynamic should be combined (not deduped) - if len(merged.Dynamic) != 2 { - t.Errorf("merge: expected 2 dynamic, got %d", len(merged.Dynamic)) - } -} - -func TestMultipleRememberAndRecall(t *testing.T) { - eng, cleanup := setup(t) - defer cleanup() - - // Store 20 memories of different types - types := []string{"convention", "decision", "bug", "spec", "task"} - for i := 0; i < 20; i++ { - eng.Remember(context.Background(), engine.RememberInput{ - Type: types[i%len(types)], - Content: fmt.Sprintf("Memory item %d about topic %d", i, i%5), - Scope: "project", - }) - } - - // Recall should find results - result, err := eng.Recall(context.Background(), engine.RecallOpts{Query: "topic", Limit: 10}) - if err != nil { - t.Fatal(err) - } - if len(result.Nodes) == 0 { - t.Error("bulk recall: expected nodes") - } - t.Logf("Stored 20, recalled %d nodes", len(result.Nodes)) - - // Status should show correct counts - st, _ := eng.Status(context.Background(), "") - if st.Nodes < 20 { - t.Errorf("status: expected ≥20 nodes, got %d", st.Nodes) - } -} - -func TestTLSCertGeneration(t *testing.T) { - dir := t.TempDir() - cfg := yaadtls.Config{Enabled: true} - - tlsCfg, err := yaadtls.TLSConfig(cfg, dir) - if err != nil { - t.Fatal(err) - } - if tlsCfg == nil { - t.Fatal("tls: nil config returned") - } - if len(tlsCfg.Certificates) == 0 { - t.Error("tls: no certificates generated") - } - - // Verify cert files were created - if _, err := os.Stat(filepath.Join(dir, "cert.pem")); err != nil { - t.Error("tls: cert.pem not created") - } - if _, err := os.Stat(filepath.Join(dir, "key.pem")); err != nil { - t.Error("tls: key.pem not created") - } -} - -func TestConfigDefaults(t *testing.T) { - cfg := config.Default() - if cfg.Server.Port != 3456 { - t.Errorf("config: expected port 3456, got %d", cfg.Server.Port) - } - if cfg.Decay.HalfLifeDays != 30 { - t.Errorf("config: expected half_life 30, got %d", cfg.Decay.HalfLifeDays) - } -} - -func TestStorageCreateAndQuery(t *testing.T) { - dir := t.TempDir() - store, err := storage.NewStore(filepath.Join(dir, "test.db")) - if err != nil { - t.Fatal(err) - } - defer func() { store.Close() }() - - ctx := context.Background() - - // Create node - node := &storage.Node{ - ID: "test-node-1", Type: "convention", Content: "Test content", - ContentHash: "hash1", Scope: "project", Tier: 1, Confidence: 1.0, Version: 1, - } - if err := store.CreateNode(ctx, node); err != nil { - t.Fatal(err) - } - - // Get node - got, err := store.GetNode(ctx, "test-node-1") - if err != nil { - t.Fatal(err) - } - if got.Content != "Test content" { - t.Errorf("storage: expected 'Test content', got '%s'", got.Content) - } - - // Create edge - edge := &storage.Edge{ - ID: "test-edge-1", FromID: "test-node-1", ToID: "test-node-1", - Type: "relates_to", Acyclic: false, Weight: 1.0, - } - if err := store.CreateEdge(ctx, edge); err != nil { - t.Fatal(err) - } - - // Get neighbors - neighbors, err := store.GetNeighbors(ctx, "test-node-1") - if err != nil { - t.Fatal(err) - } - if len(neighbors) == 0 { - t.Error("storage: expected neighbors") - } - - // Version history - if err := store.SaveVersion(ctx, "test-node-1", "old content", "test", "test update"); err != nil { - t.Fatal(err) - } - versions, err := store.GetVersions(ctx, "test-node-1") - if err != nil { - t.Fatal(err) - } - if len(versions) == 0 { - t.Error("storage: expected version history") - } -} - -func TestRESTAPI(t *testing.T) { - eng, cleanup := setup(t) - defer cleanup() - - mux := http.NewServeMux() - rest := server.NewRESTServer(eng, "") - rest.RegisterRoutes(mux) - ts := httptest.NewServer(mux) - defer ts.Close() - - // POST /yaad/remember - body, _ := json.Marshal(engine.RememberInput{ - Type: "convention", Content: "Always use TypeScript strict mode", Scope: "project", - }) - resp, err := http.Post(ts.URL+"/yaad/remember", "application/json", bytes.NewReader(body)) - if err != nil { - t.Fatal(err) - } - if resp.StatusCode != 201 { - t.Errorf("remember: expected 201, got %d", resp.StatusCode) - } - var node storage.Node - json.NewDecoder(resp.Body).Decode(&node) - resp.Body.Close() - if node.ID == "" { - t.Error("remember: empty node ID") - } - - // POST /yaad/recall - body, _ = json.Marshal(engine.RecallOpts{Query: "TypeScript", Limit: 5}) - resp, err = http.Post(ts.URL+"/yaad/recall", "application/json", bytes.NewReader(body)) - if err != nil { - t.Fatal(err) - } - if resp.StatusCode != 200 { - t.Errorf("recall: expected 200, got %d", resp.StatusCode) - } - var result engine.RecallResult - json.NewDecoder(resp.Body).Decode(&result) - resp.Body.Close() - if len(result.Nodes) == 0 { - t.Error("recall: expected nodes, got none") - } - - // GET /yaad/health - resp, _ = http.Get(ts.URL + "/yaad/health") - if resp.StatusCode != 200 { - t.Errorf("health: expected 200, got %d", resp.StatusCode) - } - resp.Body.Close() - - // GET /yaad/context - resp, _ = http.Get(ts.URL + "/yaad/context") - if resp.StatusCode != 200 { - t.Errorf("context: expected 200, got %d", resp.StatusCode) - } - resp.Body.Close() -} - -// TestConcurrentSQLiteAccess verifies that concurrent Remember and Recall operations -// against the real SQLite backend do not race or corrupt data. -func TestConcurrentSQLiteAccess(t *testing.T) { - eng, cleanup := setup(t) - defer cleanup() - - var wg sync.WaitGroup - numWriters := 5 - numReaders := 5 - opsPerGoroutine := 10 - - // Writers - for i := 0; i < numWriters; i++ { - wg.Add(1) - go func(idx int) { - defer wg.Done() - for j := 0; j < opsPerGoroutine; j++ { - _, err := eng.Remember(context.Background(), engine.RememberInput{ - Type: "convention", - Content: fmt.Sprintf("writer-%d-op-%d", idx, j), - Scope: "project", - Project: "concurrent-test", - }) - if err != nil { - // Under CI load, occasional SQLITE_BUSY is expected even with - // _busy_timeout. Skip individual operations rather than failing. - if strings.Contains(err.Error(), "database is locked") { - time.Sleep(10 * time.Millisecond) - continue - } - t.Errorf("writer %d op %d failed: %v", idx, j, err) - } - } - }(i) - } - - // Readers - for i := 0; i < numReaders; i++ { - wg.Add(1) - go func(idx int) { - defer wg.Done() - for j := 0; j < opsPerGoroutine; j++ { - _, err := eng.Recall(context.Background(), engine.RecallOpts{ - Query: "writer", - Project: "concurrent-test", - Limit: 10, - }) - if err != nil { - t.Errorf("reader %d op %d failed: %v", idx, j, err) - } - } - }(i) - } - - wg.Wait() - - st, err := eng.Status(context.Background(), "concurrent-test") - if err != nil { - t.Fatalf("status failed: %v", err) - } - expectedNodes := numWriters * opsPerGoroutine - if st.Nodes < expectedNodes { - t.Errorf("expected at least %d nodes, got %d", expectedNodes, st.Nodes) - } -} diff --git a/internal/server/mcp_concurrency_rest_test.go b/internal/server/mcp_concurrency_rest_test.go new file mode 100644 index 0000000..7a7cf5a --- /dev/null +++ b/internal/server/mcp_concurrency_rest_test.go @@ -0,0 +1,479 @@ +// This file is part of package server tests. It holds the concurrency +// (concurrent remember/forget/link) tests and the REST API-key auth +// tests moved verbatim out of mcp_test.go for readability; behavior is +// unchanged. + +package server + +import ( + "bytes" + "context" + "fmt" + "net/http" + "net/http/httptest" + "runtime" + "strings" + "sync" + "testing" + "time" + + "github.com/GrayCodeAI/yaad/storage" +) + +func TestMCPConcurrentRemember(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + const n = 10 + var wg sync.WaitGroup + var mu sync.Mutex + successes := 0 + sqliteBusy := 0 + + for i := 0; i < n; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + ctx := context.Background() + req := toolRequest("yaad_remember", map[string]any{ + "content": fmt.Sprintf("concurrent memory %d", idx), + "type": "decision", + }) + _, err := srv.handleRemember(ctx, req) + mu.Lock() + defer mu.Unlock() + if err != nil { + if strings.Contains(err.Error(), "SQLITE_BUSY") { + sqliteBusy++ + } else { + t.Errorf("goroutine %d unexpected error: %v", idx, err) + } + } else { + successes++ + } + }(i) + } + + wg.Wait() + + // With SQLite WAL mode, some SQLITE_BUSY is expected under heavy concurrency. + // The engine's write mutex serializes Remember calls, but SelfLink runs + // outside the lock, so contention is still possible. + if successes == 0 { + t.Fatal("expected at least some concurrent remembers to succeed") + } + t.Logf("concurrent remember: %d/%d succeeded, %d SQLITE_BUSY", successes, n, sqliteBusy) + + // Verify at least the successful nodes exist. + // Retry ListNodes briefly in case of SQLITE_BUSY from SelfLink cleanup. + runtime.Gosched() + time.Sleep(20 * time.Millisecond) + ctx := context.Background() + nodes, err := srv.eng.Store().ListNodes(ctx, storage.NodeFilter{}) + if err != nil { + t.Logf("ListNodes SQLITE_BUSY (expected after concurrent writes): %v", err) + return + } + if len(nodes) < successes { + t.Errorf("expected at least %d nodes, got %d", successes, len(nodes)) + } +} + +func TestMCPConcurrentRememberAndForget(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx := context.Background() + // Pre-create some nodes to forget. + var ids []string + for i := 0; i < 5; i++ { + ids = append(ids, rememberAndID(t, srv, fmt.Sprintf("pre-existing %d", i), "decision")) + } + runtime.Gosched() + time.Sleep(10 * time.Millisecond) + + const n = 5 + var wg sync.WaitGroup + var mu sync.Mutex + nonBusyErrors := 0 + + // Concurrent remembers. + for i := 0; i < n; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + req := toolRequest("yaad_remember", map[string]any{ + "content": fmt.Sprintf("new memory %d", idx), + "type": "convention", + }) + _, err := srv.handleRemember(ctx, req) + if err != nil && !strings.Contains(err.Error(), "SQLITE_BUSY") { + mu.Lock() + nonBusyErrors++ + mu.Unlock() + t.Errorf("remember %d unexpected error: %v", idx, err) + } + }(i) + } + + // Concurrent forgets. + for i, id := range ids { + wg.Add(1) + go func(idx int, nodeID string) { + defer wg.Done() + req := toolRequest("yaad_forget", map[string]any{"id": nodeID}) + _, err := srv.handleForget(ctx, req) + if err != nil && !strings.Contains(err.Error(), "SQLITE_BUSY") { + mu.Lock() + nonBusyErrors++ + mu.Unlock() + t.Errorf("forget %d unexpected error: %v", idx, err) + } + }(i, id) + } + + wg.Wait() + + if nonBusyErrors > 0 { + t.Errorf("got %d non-SQLITE_BUSY errors from concurrent operations", nonBusyErrors) + } +} + +func TestMCPConcurrentLinkAndRecall(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx := context.Background() + idA := rememberAndID(t, srv, "link source", "decision") + idB := rememberAndID(t, srv, "link target", "convention") + runtime.Gosched() + time.Sleep(5 * time.Millisecond) + + var wg sync.WaitGroup + var mu sync.Mutex + nonBusyErrors := 0 + + // Concurrent links. + for i := 0; i < 5; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + req := toolRequest("yaad_link", map[string]any{ + "from_id": idA, + "to_id": idB, + "type": "relates_to", + }) + // Linking the same pair may fail due to duplicate edge ID; that's OK. + srv.handleLink(ctx, req) + }(i) + } + + // Concurrent recalls. + for i := 0; i < 5; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + req := toolRequest("yaad_recall", map[string]any{ + "query": "link", + "limit": float64(5), + }) + _, err := srv.handleRecall(ctx, req) + if err != nil && !strings.Contains(err.Error(), "SQLITE_BUSY") { + mu.Lock() + nonBusyErrors++ + mu.Unlock() + t.Errorf("recall %d unexpected error: %v", idx, err) + } + }(i) + } + + wg.Wait() + + if nonBusyErrors > 0 { + t.Errorf("got %d non-SQLITE_BUSY errors from concurrent link/recall", nonBusyErrors) + } +} + +// --------------------------------------------------------------------------- +// 6. Authentication — API key validation on REST server +// --------------------------------------------------------------------------- + +func TestRESTAPIKeyRequired(t *testing.T) { + _, eng, cleanup := setupMCPServer(t) + defer cleanup() + + srv := NewRESTServer(eng, "") + srv.WithAPIKey("test-secret-key-12345") + + mux := http.NewServeMux() + srv.RegisterRoutes(mux) + wrapped := srv.withMiddleware(mux) + + // Request without API key should be unauthorized. + body := `{"content":"test","type":"decision"}` + req := httptest.NewRequest("POST", "/yaad/remember", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + wrapped.ServeHTTP(rr, req) + + if rr.Code != 401 { + t.Errorf("no API key: expected 401, got %d: %s", rr.Code, rr.Body.String()) + } +} + +func TestRESTAPIKeyBearerAuth(t *testing.T) { + _, eng, cleanup := setupMCPServer(t) + defer cleanup() + + srv := NewRESTServer(eng, "") + srv.WithAPIKey("test-secret-key-12345") + + mux := http.NewServeMux() + srv.RegisterRoutes(mux) + wrapped := srv.withMiddleware(mux) + + // Request with correct Bearer token should succeed. + body := `{"content":"authenticated memory","type":"decision"}` + req := httptest.NewRequest("POST", "/yaad/remember", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer test-secret-key-12345") + rr := httptest.NewRecorder() + wrapped.ServeHTTP(rr, req) + + if rr.Code != 201 { + t.Errorf("valid Bearer: expected 201, got %d: %s", rr.Code, rr.Body.String()) + } +} + +func TestRESTAPIKeyXAPIKeyHeader(t *testing.T) { + _, eng, cleanup := setupMCPServer(t) + defer cleanup() + + srv := NewRESTServer(eng, "") + srv.WithAPIKey("test-secret-key-12345") + + mux := http.NewServeMux() + srv.RegisterRoutes(mux) + wrapped := srv.withMiddleware(mux) + + // Request with X-API-Key header should succeed. + body := `{"content":"x-api-key memory","type":"decision"}` + req := httptest.NewRequest("POST", "/yaad/remember", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-API-Key", "test-secret-key-12345") + rr := httptest.NewRecorder() + wrapped.ServeHTTP(rr, req) + + if rr.Code != 201 { + t.Errorf("valid X-API-Key: expected 201, got %d: %s", rr.Code, rr.Body.String()) + } +} + +func TestRESTAPIKeyWrongKey(t *testing.T) { + _, eng, cleanup := setupMCPServer(t) + defer cleanup() + + srv := NewRESTServer(eng, "") + srv.WithAPIKey("correct-key") + + mux := http.NewServeMux() + srv.RegisterRoutes(mux) + wrapped := srv.withMiddleware(mux) + + body := `{"content":"wrong key","type":"decision"}` + req := httptest.NewRequest("POST", "/yaad/remember", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer wrong-key") + rr := httptest.NewRecorder() + wrapped.ServeHTTP(rr, req) + + if rr.Code != 401 { + t.Errorf("wrong key: expected 401, got %d: %s", rr.Code, rr.Body.String()) + } +} + +func TestRESTHealthSkipsAuth(t *testing.T) { + _, eng, cleanup := setupMCPServer(t) + defer cleanup() + + srv := NewRESTServer(eng, "") + srv.WithAPIKey("test-secret-key-12345") + + mux := http.NewServeMux() + srv.RegisterRoutes(mux) + wrapped := srv.withMiddleware(mux) + + // Health endpoint should not require auth. + req := httptest.NewRequest("GET", "/yaad/health", nil) + rr := httptest.NewRecorder() + wrapped.ServeHTTP(rr, req) + + if rr.Code != 200 { + t.Errorf("health no auth: expected 200, got %d: %s", rr.Code, rr.Body.String()) + } +} + +func TestRESTNoAPIKeyAllowsAll(t *testing.T) { + _, eng, cleanup := setupMCPServer(t) + defer cleanup() + + srv := NewRESTServer(eng, "") + // No WithAPIKey — all requests should be allowed. + + mux := http.NewServeMux() + srv.RegisterRoutes(mux) + + body := `{"content":"no auth needed","type":"decision"}` + req := httptest.NewRequest("POST", "/yaad/remember", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, req) + + if rr.Code != 201 { + t.Errorf("no API key configured: expected 201, got %d: %s", rr.Code, rr.Body.String()) + } +} + +// --------------------------------------------------------------------------- +// Additional edge cases +// --------------------------------------------------------------------------- + +func TestMCPSessionRecapNoSessions(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx := context.Background() + req := toolRequest("yaad_session_recap", map[string]any{}) + res, err := srv.handleSessionRecap(ctx, req) + if err != nil { + t.Fatalf("handleSessionRecap: %v", err) + } + text := textContent(res) + if !strings.Contains(text, "No previous sessions") { + t.Errorf("expected 'No previous sessions', got %q", text) + } +} + +func TestMCPCompact(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx := context.Background() + req := toolRequest("yaad_compact", map[string]any{}) + res, err := srv.handleCompact(ctx, req) + if err != nil { + t.Fatalf("handleCompact: %v", err) + } + text := textContent(res) + if !strings.Contains(text, "compacted") { + t.Errorf("expected 'compacted' in result, got %q", text) + } +} + +func TestMCPProactive(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx := context.Background() + rememberAndID(t, srv, "proactive test memory", "decision") + + req := toolRequest("yaad_proactive", map[string]any{"budget": float64(500)}) + res, err := srv.handleProactive(ctx, req) + if err != nil { + t.Fatalf("handleProactive: %v", err) + } + if res == nil { + t.Fatal("expected non-nil result") + } +} + +func TestMCPMentalModel(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx := context.Background() + rememberAndID(t, srv, "We use Go for backend services", "convention") + + req := toolRequest("yaad_mental_model", map[string]any{}) + res, err := srv.handleMentalModel(ctx, req) + if err != nil { + t.Fatalf("handleMentalModel: %v", err) + } + if res == nil { + t.Fatal("expected non-nil result") + } +} + +func TestMCPProfile(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx := context.Background() + rememberAndID(t, srv, "developer prefers vim", "preference") + + req := toolRequest("yaad_profile", map[string]any{}) + res, err := srv.handleProfile(ctx, req) + if err != nil { + t.Fatalf("handleProfile: %v", err) + } + if res == nil { + t.Fatal("expected non-nil result") + } +} + +// TestMCPToolRegistration verifies that all expected tools are registered +// and that the tool count matches expectations. +func TestMCPToolRegistration(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + // The MCP server registers tools via AddTool. We verify by checking that + // calling each known tool handler does not panic and that the server struct + // was constructed successfully. + if srv.server == nil { + t.Fatal("expected non-nil mcp-go server") + } + if srv.eng == nil { + t.Fatal("expected non-nil engine") + } +} + +// TestMCPRoundTripRememberRecall stores a memory and verifies it can be found +// via recall, testing the full create-search pipeline. +func TestMCPRoundTripRememberRecall(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx := context.Background() + + // Store a distinctive memory. + _, err := srv.handleRemember(ctx, toolRequest("yaad_remember", map[string]any{ + "content": "Hawk uses SQLite for persistent graph storage with WAL mode", + "type": "convention", + "tags": "sqlite,storage,architecture", + })) + if err != nil { + t.Fatalf("remember: %v", err) + } + // Allow SelfLink async edge writes to complete. + runtime.Gosched() + time.Sleep(20 * time.Millisecond) + + // Recall it. + res, err := srv.handleRecall(ctx, toolRequest("yaad_recall", map[string]any{ + "query": "SQLite persistent storage", + "depth": float64(2), + "limit": float64(10), + })) + if err != nil { + t.Fatalf("recall: %v", err) + } + + text := textContent(res) + if text == "" || text == "null" { + t.Fatal("expected recall to find the stored memory") + } + if !strings.Contains(text, "SQLite") { + t.Errorf("recall result should contain 'SQLite', got %q", text[:min(200, len(text))]) + } +} diff --git a/internal/server/mcp_graph_test.go b/internal/server/mcp_graph_test.go new file mode 100644 index 0000000..1c80228 --- /dev/null +++ b/internal/server/mcp_graph_test.go @@ -0,0 +1,476 @@ +// This file is part of package server tests. It holds the MCP link, +// subgraph, status, session, skill, feedback, context, and +// clamp/cancellation tests moved verbatim out of mcp_test.go for +// readability; behavior is unchanged. + +package server + +import ( + "context" + "encoding/json" + "runtime" + "strings" + "testing" + "time" + + "github.com/GrayCodeAI/yaad/storage" +) + +func TestMCPLinkAndUnlink(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx := context.Background() + idA := rememberAndID(t, srv, "node A", "decision") + idB := rememberAndID(t, srv, "node B", "convention") + runtime.Gosched() + time.Sleep(5 * time.Millisecond) + + // Link. + req := toolRequest("yaad_link", map[string]any{ + "from_id": idA, + "to_id": idB, + "type": "relates_to", + }) + res, err := srv.handleLink(ctx, req) + if err != nil { + t.Fatalf("handleLink: %v", err) + } + text := textContent(res) + if !strings.Contains(text, idA) || !strings.Contains(text, idB) { + t.Errorf("link result should contain node IDs, got %q", text) + } + + // Parse edge to get ID for unlink. + var edge storage.Edge + json.Unmarshal([]byte(text), &edge) + + // Unlink. + req = toolRequest("yaad_unlink", map[string]any{"id": edge.ID}) + res, err = srv.handleUnlink(ctx, req) + if err != nil { + t.Fatalf("handleUnlink: %v", err) + } + if textContent(res) != "unlinked" { + t.Errorf("unlink result = %q, want 'unlinked'", textContent(res)) + } +} + +func TestMCPLinkInvalidEdgeType(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx := context.Background() + idA := rememberAndID(t, srv, "A", "decision") + idB := rememberAndID(t, srv, "B", "decision") + runtime.Gosched() + time.Sleep(5 * time.Millisecond) + + req := toolRequest("yaad_link", map[string]any{ + "from_id": idA, + "to_id": idB, + "type": "invalid_edge_xyz", + }) + _, err := srv.handleLink(ctx, req) + if err == nil { + t.Fatal("expected error for invalid edge type") + } + if !strings.Contains(err.Error(), "invalid edge type") { + t.Errorf("error = %q, want 'invalid edge type'", err.Error()) + } +} + +func TestMCPLinkMissingEdgeType(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx := context.Background() + idA := rememberAndID(t, srv, "A", "decision") + idB := rememberAndID(t, srv, "B", "decision") + runtime.Gosched() + time.Sleep(5 * time.Millisecond) + + req := toolRequest("yaad_link", map[string]any{ + "from_id": idA, + "to_id": idB, + }) + _, err := srv.handleLink(ctx, req) + if err == nil { + t.Fatal("expected error for missing edge type") + } +} + +func TestMCPSubgraph(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx := context.Background() + idA := rememberAndID(t, srv, "center node", "decision") + idB := rememberAndID(t, srv, "neighbor node", "convention") + runtime.Gosched() + time.Sleep(5 * time.Millisecond) + + // Create an edge so BFS finds the neighbor. + edgeReq := toolRequest("yaad_link", map[string]any{ + "from_id": idA, + "to_id": idB, + "type": "relates_to", + }) + srv.handleLink(ctx, edgeReq) + runtime.Gosched() + time.Sleep(5 * time.Millisecond) + + req := toolRequest("yaad_subgraph", map[string]any{ + "id": idA, + "depth": float64(2), + }) + res, err := srv.handleSubgraph(ctx, req) + if err != nil { + t.Fatalf("handleSubgraph: %v", err) + } + text := textContent(res) + if text == "" || text == "null" { + t.Error("expected non-empty subgraph") + } +} + +func TestMCPStatus(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx := context.Background() + rememberAndID(t, srv, "status test node", "decision") + + req := toolRequest("yaad_status", map[string]any{}) + res, err := srv.handleStatus(ctx, req) + if err != nil { + t.Fatalf("handleStatus: %v", err) + } + text := textContent(res) + if !strings.Contains(text, "Nodes") { + t.Errorf("status result should contain 'Nodes', got %q", text) + } +} + +func TestMCPSessions(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx := context.Background() + req := toolRequest("yaad_sessions", map[string]any{"limit": float64(5)}) + res, err := srv.handleSessions(ctx, req) + if err != nil { + t.Fatalf("handleSessions: %v", err) + } + if res == nil { + t.Fatal("expected non-nil result") + } +} + +func TestMCPSkillStoreAndGet(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx := context.Background() + steps, _ := json.Marshal([]string{"Step 1: Write code", "Step 2: Write tests", "Step 3: Submit PR"}) + + // Store skill. + req := toolRequest("yaad_skill_store", map[string]any{ + "name": "code-review-workflow", + "description": "Standard code review workflow", + "steps": string(steps), + }) + res, err := srv.handleSkillStore(ctx, req) + if err != nil { + t.Fatalf("handleSkillStore: %v", err) + } + text := textContent(res) + if !strings.Contains(text, "code-review-workflow") { + t.Errorf("skill store result = %q, want to contain skill name", text) + } +} + +func TestMCPSkillStoreInvalidSteps(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx := context.Background() + req := toolRequest("yaad_skill_store", map[string]any{ + "name": "bad-skill", + "description": "invalid steps", + "steps": "not a json array", + }) + _, err := srv.handleSkillStore(ctx, req) + if err == nil { + t.Fatal("expected error for invalid steps JSON") + } + if !strings.Contains(err.Error(), "JSON array") { + t.Errorf("error = %q, want 'JSON array'", err.Error()) + } +} + +func TestMCPFeedbackApprove(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx := context.Background() + id := rememberAndID(t, srv, "pending memory", "decision") + + req := toolRequest("yaad_feedback", map[string]any{ + "id": id, + "action": "approve", + }) + res, err := srv.handleFeedback(ctx, req) + if err != nil { + t.Fatalf("handleFeedback approve: %v", err) + } + if !strings.Contains(textContent(res), "approve") { + t.Errorf("expected 'approve' in result, got %q", textContent(res)) + } +} + +func TestMCPFeedbackEdit(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx := context.Background() + id := rememberAndID(t, srv, "memory to edit", "decision") + + req := toolRequest("yaad_feedback", map[string]any{ + "id": id, + "action": "edit", + "content": "edited content", + }) + res, err := srv.handleFeedback(ctx, req) + if err != nil { + t.Fatalf("handleFeedback edit: %v", err) + } + if !strings.Contains(textContent(res), "edit") { + t.Errorf("expected 'edit' in result, got %q", textContent(res)) + } +} + +func TestMCPFeedbackDiscard(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx := context.Background() + id := rememberAndID(t, srv, "memory to discard", "decision") + + req := toolRequest("yaad_feedback", map[string]any{ + "id": id, + "action": "discard", + }) + res, err := srv.handleFeedback(ctx, req) + if err != nil { + t.Fatalf("handleFeedback discard: %v", err) + } + if !strings.Contains(textContent(res), "discard") { + t.Errorf("expected 'discard' in result, got %q", textContent(res)) + } +} + +func TestMCPFeedbackInvalidAction(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx := context.Background() + id := rememberAndID(t, srv, "some memory", "decision") + + req := toolRequest("yaad_feedback", map[string]any{ + "id": id, + "action": "invalid_action", + }) + _, err := srv.handleFeedback(ctx, req) + if err == nil { + t.Fatal("expected error for invalid feedback action") + } +} + +// --------------------------------------------------------------------------- +// 4. Error handling — malformed requests, edge cases +// --------------------------------------------------------------------------- + +func TestMCPRecallEmptyQuery(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx := context.Background() + // Empty query should not error — it should list nodes. + req := toolRequest("yaad_recall", map[string]any{ + "query": "", + "limit": float64(5), + }) + res, err := srv.handleRecall(ctx, req) + if err != nil { + t.Fatalf("recall with empty query should not error: %v", err) + } + if res == nil { + t.Fatal("expected non-nil result") + } +} + +func TestMCPContext(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx := context.Background() + rememberAndID(t, srv, "context test", "decision") + + req := toolRequest("yaad_context", map[string]any{}) + res, err := srv.handleContext(ctx, req) + if err != nil { + t.Fatalf("handleContext: %v", err) + } + if res == nil { + t.Fatal("expected non-nil result") + } +} + +func TestMCPContextCanceled(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + + req := toolRequest("yaad_context", map[string]any{}) + _, err := srv.handleContext(ctx, req) + if err == nil { + t.Fatal("expected error from canceled context") + } +} + +func TestMCPRecallCanceledContext(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + req := toolRequest("yaad_recall", map[string]any{"query": "test"}) + _, err := srv.handleRecall(ctx, req) + if err == nil { + t.Fatal("expected error from canceled context") + } +} + +func TestMCPRememberCanceledContext(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + req := toolRequest("yaad_remember", map[string]any{"content": "test"}) + _, err := srv.handleRemember(ctx, req) + if err == nil { + t.Fatal("expected error from canceled context") + } +} + +func TestMCPForgetCanceledContext(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + req := toolRequest("yaad_forget", map[string]any{"id": "any"}) + _, err := srv.handleForget(ctx, req) + if err == nil { + t.Fatal("expected error from canceled context") + } +} + +func TestMCPImpactCanceledContext(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + req := toolRequest("yaad_impact", map[string]any{"file": "/some/file.go"}) + _, err := srv.handleImpact(ctx, req) + if err == nil { + t.Fatal("expected error from canceled context") + } +} + +func TestMCPSubgraphDepthClamp(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx := context.Background() + id := rememberAndID(t, srv, "depth test", "decision") + + // depth > 5 should be clamped to 2. + req := toolRequest("yaad_subgraph", map[string]any{ + "id": id, + "depth": float64(99), + }) + res, err := srv.handleSubgraph(ctx, req) + if err != nil { + t.Fatalf("subgraph with clamped depth should not error: %v", err) + } + if res == nil { + t.Fatal("expected non-nil result") + } +} + +func TestMCPImpactDepthClamp(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx := context.Background() + // depth > 5 should be clamped to 3. + req := toolRequest("yaad_impact", map[string]any{ + "file": "/some/file.go", + "depth": float64(99), + }) + res, err := srv.handleImpact(ctx, req) + if err != nil { + t.Fatalf("impact with clamped depth should not error: %v", err) + } + if res == nil { + t.Fatal("expected non-nil result") + } +} + +func TestMCPRememberContentTooLong(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx := context.Background() + longContent := strings.Repeat("x", 20000) + req := toolRequest("yaad_remember", map[string]any{ + "content": longContent, + }) + _, err := srv.handleRemember(ctx, req) + if err == nil { + t.Fatal("expected error for content exceeding max length") + } + if !strings.Contains(err.Error(), "max length") { + t.Errorf("error = %q, want 'max length'", err.Error()) + } +} + +func TestMCPVerifyWithoutIntegrity(t *testing.T) { + srv, _, cleanup := setupMCPServer(t) + defer cleanup() + + ctx := context.Background() + // Engine is created without WithIntegrity, so verify should fail. + req := toolRequest("yaad_verify", map[string]any{}) + _, err := srv.handleVerify(ctx, req) + if err == nil { + t.Fatal("expected error when integrity checker not configured") + } + if !strings.Contains(err.Error(), "integrity checker not configured") { + t.Errorf("error = %q, want 'integrity checker not configured'", err.Error()) + } +} + +// --------------------------------------------------------------------------- +// 5. Concurrency — parallel memory operations +// --------------------------------------------------------------------------- diff --git a/internal/server/mcp_test.go b/internal/server/mcp_test.go index 0402dc9..1bc9457 100644 --- a/internal/server/mcp_test.go +++ b/internal/server/mcp_test.go @@ -1,16 +1,11 @@ package server import ( - "bytes" "context" "encoding/json" - "fmt" - "net/http" - "net/http/httptest" "path/filepath" "runtime" "strings" - "sync" "testing" "time" @@ -92,6 +87,11 @@ func rememberAndID(t *testing.T, srv *MCPServer, content, typ string) string { // 1. Memory operations — Create, Read, Update, Delete // --------------------------------------------------------------------------- +// Note: MCP graph/session/feedback tests and concurrency/REST-auth tests +// were moved verbatim into mcp_graph_test.go and +// mcp_concurrency_rest_test.go for readability. This file keeps the shared +// test helpers plus the remember/forget/pin/recall test groups. + func TestMCPRememberCreatesNode(t *testing.T) { srv, _, cleanup := setupMCPServer(t) defer cleanup() @@ -441,923 +441,6 @@ func TestMCPHybridRecall(t *testing.T) { // 3. Tool handling — correct dispatch and edge linking // --------------------------------------------------------------------------- -func TestMCPLinkAndUnlink(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx := context.Background() - idA := rememberAndID(t, srv, "node A", "decision") - idB := rememberAndID(t, srv, "node B", "convention") - runtime.Gosched() - time.Sleep(5 * time.Millisecond) - - // Link. - req := toolRequest("yaad_link", map[string]any{ - "from_id": idA, - "to_id": idB, - "type": "relates_to", - }) - res, err := srv.handleLink(ctx, req) - if err != nil { - t.Fatalf("handleLink: %v", err) - } - text := textContent(res) - if !strings.Contains(text, idA) || !strings.Contains(text, idB) { - t.Errorf("link result should contain node IDs, got %q", text) - } - - // Parse edge to get ID for unlink. - var edge storage.Edge - json.Unmarshal([]byte(text), &edge) - - // Unlink. - req = toolRequest("yaad_unlink", map[string]any{"id": edge.ID}) - res, err = srv.handleUnlink(ctx, req) - if err != nil { - t.Fatalf("handleUnlink: %v", err) - } - if textContent(res) != "unlinked" { - t.Errorf("unlink result = %q, want 'unlinked'", textContent(res)) - } -} - -func TestMCPLinkInvalidEdgeType(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx := context.Background() - idA := rememberAndID(t, srv, "A", "decision") - idB := rememberAndID(t, srv, "B", "decision") - runtime.Gosched() - time.Sleep(5 * time.Millisecond) - - req := toolRequest("yaad_link", map[string]any{ - "from_id": idA, - "to_id": idB, - "type": "invalid_edge_xyz", - }) - _, err := srv.handleLink(ctx, req) - if err == nil { - t.Fatal("expected error for invalid edge type") - } - if !strings.Contains(err.Error(), "invalid edge type") { - t.Errorf("error = %q, want 'invalid edge type'", err.Error()) - } -} - -func TestMCPLinkMissingEdgeType(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx := context.Background() - idA := rememberAndID(t, srv, "A", "decision") - idB := rememberAndID(t, srv, "B", "decision") - runtime.Gosched() - time.Sleep(5 * time.Millisecond) - - req := toolRequest("yaad_link", map[string]any{ - "from_id": idA, - "to_id": idB, - }) - _, err := srv.handleLink(ctx, req) - if err == nil { - t.Fatal("expected error for missing edge type") - } -} - -func TestMCPSubgraph(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx := context.Background() - idA := rememberAndID(t, srv, "center node", "decision") - idB := rememberAndID(t, srv, "neighbor node", "convention") - runtime.Gosched() - time.Sleep(5 * time.Millisecond) - - // Create an edge so BFS finds the neighbor. - edgeReq := toolRequest("yaad_link", map[string]any{ - "from_id": idA, - "to_id": idB, - "type": "relates_to", - }) - srv.handleLink(ctx, edgeReq) - runtime.Gosched() - time.Sleep(5 * time.Millisecond) - - req := toolRequest("yaad_subgraph", map[string]any{ - "id": idA, - "depth": float64(2), - }) - res, err := srv.handleSubgraph(ctx, req) - if err != nil { - t.Fatalf("handleSubgraph: %v", err) - } - text := textContent(res) - if text == "" || text == "null" { - t.Error("expected non-empty subgraph") - } -} - -func TestMCPStatus(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx := context.Background() - rememberAndID(t, srv, "status test node", "decision") - - req := toolRequest("yaad_status", map[string]any{}) - res, err := srv.handleStatus(ctx, req) - if err != nil { - t.Fatalf("handleStatus: %v", err) - } - text := textContent(res) - if !strings.Contains(text, "Nodes") { - t.Errorf("status result should contain 'Nodes', got %q", text) - } -} - -func TestMCPSessions(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx := context.Background() - req := toolRequest("yaad_sessions", map[string]any{"limit": float64(5)}) - res, err := srv.handleSessions(ctx, req) - if err != nil { - t.Fatalf("handleSessions: %v", err) - } - if res == nil { - t.Fatal("expected non-nil result") - } -} - -func TestMCPSkillStoreAndGet(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx := context.Background() - steps, _ := json.Marshal([]string{"Step 1: Write code", "Step 2: Write tests", "Step 3: Submit PR"}) - - // Store skill. - req := toolRequest("yaad_skill_store", map[string]any{ - "name": "code-review-workflow", - "description": "Standard code review workflow", - "steps": string(steps), - }) - res, err := srv.handleSkillStore(ctx, req) - if err != nil { - t.Fatalf("handleSkillStore: %v", err) - } - text := textContent(res) - if !strings.Contains(text, "code-review-workflow") { - t.Errorf("skill store result = %q, want to contain skill name", text) - } -} - -func TestMCPSkillStoreInvalidSteps(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx := context.Background() - req := toolRequest("yaad_skill_store", map[string]any{ - "name": "bad-skill", - "description": "invalid steps", - "steps": "not a json array", - }) - _, err := srv.handleSkillStore(ctx, req) - if err == nil { - t.Fatal("expected error for invalid steps JSON") - } - if !strings.Contains(err.Error(), "JSON array") { - t.Errorf("error = %q, want 'JSON array'", err.Error()) - } -} - -func TestMCPFeedbackApprove(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx := context.Background() - id := rememberAndID(t, srv, "pending memory", "decision") - - req := toolRequest("yaad_feedback", map[string]any{ - "id": id, - "action": "approve", - }) - res, err := srv.handleFeedback(ctx, req) - if err != nil { - t.Fatalf("handleFeedback approve: %v", err) - } - if !strings.Contains(textContent(res), "approve") { - t.Errorf("expected 'approve' in result, got %q", textContent(res)) - } -} - -func TestMCPFeedbackEdit(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx := context.Background() - id := rememberAndID(t, srv, "memory to edit", "decision") - - req := toolRequest("yaad_feedback", map[string]any{ - "id": id, - "action": "edit", - "content": "edited content", - }) - res, err := srv.handleFeedback(ctx, req) - if err != nil { - t.Fatalf("handleFeedback edit: %v", err) - } - if !strings.Contains(textContent(res), "edit") { - t.Errorf("expected 'edit' in result, got %q", textContent(res)) - } -} - -func TestMCPFeedbackDiscard(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx := context.Background() - id := rememberAndID(t, srv, "memory to discard", "decision") - - req := toolRequest("yaad_feedback", map[string]any{ - "id": id, - "action": "discard", - }) - res, err := srv.handleFeedback(ctx, req) - if err != nil { - t.Fatalf("handleFeedback discard: %v", err) - } - if !strings.Contains(textContent(res), "discard") { - t.Errorf("expected 'discard' in result, got %q", textContent(res)) - } -} - -func TestMCPFeedbackInvalidAction(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx := context.Background() - id := rememberAndID(t, srv, "some memory", "decision") - - req := toolRequest("yaad_feedback", map[string]any{ - "id": id, - "action": "invalid_action", - }) - _, err := srv.handleFeedback(ctx, req) - if err == nil { - t.Fatal("expected error for invalid feedback action") - } -} - -// --------------------------------------------------------------------------- -// 4. Error handling — malformed requests, edge cases -// --------------------------------------------------------------------------- - -func TestMCPRecallEmptyQuery(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx := context.Background() - // Empty query should not error — it should list nodes. - req := toolRequest("yaad_recall", map[string]any{ - "query": "", - "limit": float64(5), - }) - res, err := srv.handleRecall(ctx, req) - if err != nil { - t.Fatalf("recall with empty query should not error: %v", err) - } - if res == nil { - t.Fatal("expected non-nil result") - } -} - -func TestMCPContext(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx := context.Background() - rememberAndID(t, srv, "context test", "decision") - - req := toolRequest("yaad_context", map[string]any{}) - res, err := srv.handleContext(ctx, req) - if err != nil { - t.Fatalf("handleContext: %v", err) - } - if res == nil { - t.Fatal("expected non-nil result") - } -} - -func TestMCPContextCanceled(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx, cancel := context.WithCancel(context.Background()) - cancel() // cancel immediately - - req := toolRequest("yaad_context", map[string]any{}) - _, err := srv.handleContext(ctx, req) - if err == nil { - t.Fatal("expected error from canceled context") - } -} - -func TestMCPRecallCanceledContext(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - req := toolRequest("yaad_recall", map[string]any{"query": "test"}) - _, err := srv.handleRecall(ctx, req) - if err == nil { - t.Fatal("expected error from canceled context") - } -} - -func TestMCPRememberCanceledContext(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - req := toolRequest("yaad_remember", map[string]any{"content": "test"}) - _, err := srv.handleRemember(ctx, req) - if err == nil { - t.Fatal("expected error from canceled context") - } -} - -func TestMCPForgetCanceledContext(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - req := toolRequest("yaad_forget", map[string]any{"id": "any"}) - _, err := srv.handleForget(ctx, req) - if err == nil { - t.Fatal("expected error from canceled context") - } -} - -func TestMCPImpactCanceledContext(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - req := toolRequest("yaad_impact", map[string]any{"file": "/some/file.go"}) - _, err := srv.handleImpact(ctx, req) - if err == nil { - t.Fatal("expected error from canceled context") - } -} - -func TestMCPSubgraphDepthClamp(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx := context.Background() - id := rememberAndID(t, srv, "depth test", "decision") - - // depth > 5 should be clamped to 2. - req := toolRequest("yaad_subgraph", map[string]any{ - "id": id, - "depth": float64(99), - }) - res, err := srv.handleSubgraph(ctx, req) - if err != nil { - t.Fatalf("subgraph with clamped depth should not error: %v", err) - } - if res == nil { - t.Fatal("expected non-nil result") - } -} - -func TestMCPImpactDepthClamp(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx := context.Background() - // depth > 5 should be clamped to 3. - req := toolRequest("yaad_impact", map[string]any{ - "file": "/some/file.go", - "depth": float64(99), - }) - res, err := srv.handleImpact(ctx, req) - if err != nil { - t.Fatalf("impact with clamped depth should not error: %v", err) - } - if res == nil { - t.Fatal("expected non-nil result") - } -} - -func TestMCPRememberContentTooLong(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx := context.Background() - longContent := strings.Repeat("x", 20000) - req := toolRequest("yaad_remember", map[string]any{ - "content": longContent, - }) - _, err := srv.handleRemember(ctx, req) - if err == nil { - t.Fatal("expected error for content exceeding max length") - } - if !strings.Contains(err.Error(), "max length") { - t.Errorf("error = %q, want 'max length'", err.Error()) - } -} - -func TestMCPVerifyWithoutIntegrity(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx := context.Background() - // Engine is created without WithIntegrity, so verify should fail. - req := toolRequest("yaad_verify", map[string]any{}) - _, err := srv.handleVerify(ctx, req) - if err == nil { - t.Fatal("expected error when integrity checker not configured") - } - if !strings.Contains(err.Error(), "integrity checker not configured") { - t.Errorf("error = %q, want 'integrity checker not configured'", err.Error()) - } -} - -// --------------------------------------------------------------------------- -// 5. Concurrency — parallel memory operations -// --------------------------------------------------------------------------- - -func TestMCPConcurrentRemember(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - const n = 10 - var wg sync.WaitGroup - var mu sync.Mutex - successes := 0 - sqliteBusy := 0 - - for i := 0; i < n; i++ { - wg.Add(1) - go func(idx int) { - defer wg.Done() - ctx := context.Background() - req := toolRequest("yaad_remember", map[string]any{ - "content": fmt.Sprintf("concurrent memory %d", idx), - "type": "decision", - }) - _, err := srv.handleRemember(ctx, req) - mu.Lock() - defer mu.Unlock() - if err != nil { - if strings.Contains(err.Error(), "SQLITE_BUSY") { - sqliteBusy++ - } else { - t.Errorf("goroutine %d unexpected error: %v", idx, err) - } - } else { - successes++ - } - }(i) - } - - wg.Wait() - - // With SQLite WAL mode, some SQLITE_BUSY is expected under heavy concurrency. - // The engine's write mutex serializes Remember calls, but SelfLink runs - // outside the lock, so contention is still possible. - if successes == 0 { - t.Fatal("expected at least some concurrent remembers to succeed") - } - t.Logf("concurrent remember: %d/%d succeeded, %d SQLITE_BUSY", successes, n, sqliteBusy) - - // Verify at least the successful nodes exist. - // Retry ListNodes briefly in case of SQLITE_BUSY from SelfLink cleanup. - runtime.Gosched() - time.Sleep(20 * time.Millisecond) - ctx := context.Background() - nodes, err := srv.eng.Store().ListNodes(ctx, storage.NodeFilter{}) - if err != nil { - t.Logf("ListNodes SQLITE_BUSY (expected after concurrent writes): %v", err) - return - } - if len(nodes) < successes { - t.Errorf("expected at least %d nodes, got %d", successes, len(nodes)) - } -} - -func TestMCPConcurrentRememberAndForget(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx := context.Background() - // Pre-create some nodes to forget. - var ids []string - for i := 0; i < 5; i++ { - ids = append(ids, rememberAndID(t, srv, fmt.Sprintf("pre-existing %d", i), "decision")) - } - runtime.Gosched() - time.Sleep(10 * time.Millisecond) - - const n = 5 - var wg sync.WaitGroup - var mu sync.Mutex - nonBusyErrors := 0 - - // Concurrent remembers. - for i := 0; i < n; i++ { - wg.Add(1) - go func(idx int) { - defer wg.Done() - req := toolRequest("yaad_remember", map[string]any{ - "content": fmt.Sprintf("new memory %d", idx), - "type": "convention", - }) - _, err := srv.handleRemember(ctx, req) - if err != nil && !strings.Contains(err.Error(), "SQLITE_BUSY") { - mu.Lock() - nonBusyErrors++ - mu.Unlock() - t.Errorf("remember %d unexpected error: %v", idx, err) - } - }(i) - } - - // Concurrent forgets. - for i, id := range ids { - wg.Add(1) - go func(idx int, nodeID string) { - defer wg.Done() - req := toolRequest("yaad_forget", map[string]any{"id": nodeID}) - _, err := srv.handleForget(ctx, req) - if err != nil && !strings.Contains(err.Error(), "SQLITE_BUSY") { - mu.Lock() - nonBusyErrors++ - mu.Unlock() - t.Errorf("forget %d unexpected error: %v", idx, err) - } - }(i, id) - } - - wg.Wait() - - if nonBusyErrors > 0 { - t.Errorf("got %d non-SQLITE_BUSY errors from concurrent operations", nonBusyErrors) - } -} - -func TestMCPConcurrentLinkAndRecall(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx := context.Background() - idA := rememberAndID(t, srv, "link source", "decision") - idB := rememberAndID(t, srv, "link target", "convention") - runtime.Gosched() - time.Sleep(5 * time.Millisecond) - - var wg sync.WaitGroup - var mu sync.Mutex - nonBusyErrors := 0 - - // Concurrent links. - for i := 0; i < 5; i++ { - wg.Add(1) - go func(idx int) { - defer wg.Done() - req := toolRequest("yaad_link", map[string]any{ - "from_id": idA, - "to_id": idB, - "type": "relates_to", - }) - // Linking the same pair may fail due to duplicate edge ID; that's OK. - srv.handleLink(ctx, req) - }(i) - } - - // Concurrent recalls. - for i := 0; i < 5; i++ { - wg.Add(1) - go func(idx int) { - defer wg.Done() - req := toolRequest("yaad_recall", map[string]any{ - "query": "link", - "limit": float64(5), - }) - _, err := srv.handleRecall(ctx, req) - if err != nil && !strings.Contains(err.Error(), "SQLITE_BUSY") { - mu.Lock() - nonBusyErrors++ - mu.Unlock() - t.Errorf("recall %d unexpected error: %v", idx, err) - } - }(i) - } - - wg.Wait() - - if nonBusyErrors > 0 { - t.Errorf("got %d non-SQLITE_BUSY errors from concurrent link/recall", nonBusyErrors) - } -} - -// --------------------------------------------------------------------------- -// 6. Authentication — API key validation on REST server -// --------------------------------------------------------------------------- - -func TestRESTAPIKeyRequired(t *testing.T) { - _, eng, cleanup := setupMCPServer(t) - defer cleanup() - - srv := NewRESTServer(eng, "") - srv.WithAPIKey("test-secret-key-12345") - - mux := http.NewServeMux() - srv.RegisterRoutes(mux) - wrapped := srv.withMiddleware(mux) - - // Request without API key should be unauthorized. - body := `{"content":"test","type":"decision"}` - req := httptest.NewRequest("POST", "/yaad/remember", bytes.NewBufferString(body)) - req.Header.Set("Content-Type", "application/json") - rr := httptest.NewRecorder() - wrapped.ServeHTTP(rr, req) - - if rr.Code != 401 { - t.Errorf("no API key: expected 401, got %d: %s", rr.Code, rr.Body.String()) - } -} - -func TestRESTAPIKeyBearerAuth(t *testing.T) { - _, eng, cleanup := setupMCPServer(t) - defer cleanup() - - srv := NewRESTServer(eng, "") - srv.WithAPIKey("test-secret-key-12345") - - mux := http.NewServeMux() - srv.RegisterRoutes(mux) - wrapped := srv.withMiddleware(mux) - - // Request with correct Bearer token should succeed. - body := `{"content":"authenticated memory","type":"decision"}` - req := httptest.NewRequest("POST", "/yaad/remember", bytes.NewBufferString(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer test-secret-key-12345") - rr := httptest.NewRecorder() - wrapped.ServeHTTP(rr, req) - - if rr.Code != 201 { - t.Errorf("valid Bearer: expected 201, got %d: %s", rr.Code, rr.Body.String()) - } -} - -func TestRESTAPIKeyXAPIKeyHeader(t *testing.T) { - _, eng, cleanup := setupMCPServer(t) - defer cleanup() - - srv := NewRESTServer(eng, "") - srv.WithAPIKey("test-secret-key-12345") - - mux := http.NewServeMux() - srv.RegisterRoutes(mux) - wrapped := srv.withMiddleware(mux) - - // Request with X-API-Key header should succeed. - body := `{"content":"x-api-key memory","type":"decision"}` - req := httptest.NewRequest("POST", "/yaad/remember", bytes.NewBufferString(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("X-API-Key", "test-secret-key-12345") - rr := httptest.NewRecorder() - wrapped.ServeHTTP(rr, req) - - if rr.Code != 201 { - t.Errorf("valid X-API-Key: expected 201, got %d: %s", rr.Code, rr.Body.String()) - } -} - -func TestRESTAPIKeyWrongKey(t *testing.T) { - _, eng, cleanup := setupMCPServer(t) - defer cleanup() - - srv := NewRESTServer(eng, "") - srv.WithAPIKey("correct-key") - - mux := http.NewServeMux() - srv.RegisterRoutes(mux) - wrapped := srv.withMiddleware(mux) - - body := `{"content":"wrong key","type":"decision"}` - req := httptest.NewRequest("POST", "/yaad/remember", bytes.NewBufferString(body)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer wrong-key") - rr := httptest.NewRecorder() - wrapped.ServeHTTP(rr, req) - - if rr.Code != 401 { - t.Errorf("wrong key: expected 401, got %d: %s", rr.Code, rr.Body.String()) - } -} - -func TestRESTHealthSkipsAuth(t *testing.T) { - _, eng, cleanup := setupMCPServer(t) - defer cleanup() - - srv := NewRESTServer(eng, "") - srv.WithAPIKey("test-secret-key-12345") - - mux := http.NewServeMux() - srv.RegisterRoutes(mux) - wrapped := srv.withMiddleware(mux) - - // Health endpoint should not require auth. - req := httptest.NewRequest("GET", "/yaad/health", nil) - rr := httptest.NewRecorder() - wrapped.ServeHTTP(rr, req) - - if rr.Code != 200 { - t.Errorf("health no auth: expected 200, got %d: %s", rr.Code, rr.Body.String()) - } -} - -func TestRESTNoAPIKeyAllowsAll(t *testing.T) { - _, eng, cleanup := setupMCPServer(t) - defer cleanup() - - srv := NewRESTServer(eng, "") - // No WithAPIKey — all requests should be allowed. - - mux := http.NewServeMux() - srv.RegisterRoutes(mux) - - body := `{"content":"no auth needed","type":"decision"}` - req := httptest.NewRequest("POST", "/yaad/remember", bytes.NewBufferString(body)) - req.Header.Set("Content-Type", "application/json") - rr := httptest.NewRecorder() - mux.ServeHTTP(rr, req) - - if rr.Code != 201 { - t.Errorf("no API key configured: expected 201, got %d: %s", rr.Code, rr.Body.String()) - } -} - -// --------------------------------------------------------------------------- -// Additional edge cases -// --------------------------------------------------------------------------- - -func TestMCPSessionRecapNoSessions(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx := context.Background() - req := toolRequest("yaad_session_recap", map[string]any{}) - res, err := srv.handleSessionRecap(ctx, req) - if err != nil { - t.Fatalf("handleSessionRecap: %v", err) - } - text := textContent(res) - if !strings.Contains(text, "No previous sessions") { - t.Errorf("expected 'No previous sessions', got %q", text) - } -} - -func TestMCPCompact(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx := context.Background() - req := toolRequest("yaad_compact", map[string]any{}) - res, err := srv.handleCompact(ctx, req) - if err != nil { - t.Fatalf("handleCompact: %v", err) - } - text := textContent(res) - if !strings.Contains(text, "compacted") { - t.Errorf("expected 'compacted' in result, got %q", text) - } -} - -func TestMCPProactive(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx := context.Background() - rememberAndID(t, srv, "proactive test memory", "decision") - - req := toolRequest("yaad_proactive", map[string]any{"budget": float64(500)}) - res, err := srv.handleProactive(ctx, req) - if err != nil { - t.Fatalf("handleProactive: %v", err) - } - if res == nil { - t.Fatal("expected non-nil result") - } -} - -func TestMCPMentalModel(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx := context.Background() - rememberAndID(t, srv, "We use Go for backend services", "convention") - - req := toolRequest("yaad_mental_model", map[string]any{}) - res, err := srv.handleMentalModel(ctx, req) - if err != nil { - t.Fatalf("handleMentalModel: %v", err) - } - if res == nil { - t.Fatal("expected non-nil result") - } -} - -func TestMCPProfile(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx := context.Background() - rememberAndID(t, srv, "developer prefers vim", "preference") - - req := toolRequest("yaad_profile", map[string]any{}) - res, err := srv.handleProfile(ctx, req) - if err != nil { - t.Fatalf("handleProfile: %v", err) - } - if res == nil { - t.Fatal("expected non-nil result") - } -} - -// TestMCPToolRegistration verifies that all expected tools are registered -// and that the tool count matches expectations. -func TestMCPToolRegistration(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - // The MCP server registers tools via AddTool. We verify by checking that - // calling each known tool handler does not panic and that the server struct - // was constructed successfully. - if srv.server == nil { - t.Fatal("expected non-nil mcp-go server") - } - if srv.eng == nil { - t.Fatal("expected non-nil engine") - } -} - -// TestMCPRoundTripRememberRecall stores a memory and verifies it can be found -// via recall, testing the full create-search pipeline. -func TestMCPRoundTripRememberRecall(t *testing.T) { - srv, _, cleanup := setupMCPServer(t) - defer cleanup() - - ctx := context.Background() - - // Store a distinctive memory. - _, err := srv.handleRemember(ctx, toolRequest("yaad_remember", map[string]any{ - "content": "Hawk uses SQLite for persistent graph storage with WAL mode", - "type": "convention", - "tags": "sqlite,storage,architecture", - })) - if err != nil { - t.Fatalf("remember: %v", err) - } - // Allow SelfLink async edge writes to complete. - runtime.Gosched() - time.Sleep(20 * time.Millisecond) - - // Recall it. - res, err := srv.handleRecall(ctx, toolRequest("yaad_recall", map[string]any{ - "query": "SQLite persistent storage", - "depth": float64(2), - "limit": float64(10), - })) - if err != nil { - t.Fatalf("recall: %v", err) - } - - text := textContent(res) - if text == "" || text == "null" { - t.Fatal("expected recall to find the stored memory") - } - if !strings.Contains(text, "SQLite") { - t.Errorf("recall result should contain 'SQLite', got %q", text[:min(200, len(text))]) - } -} - func min(a, b int) int { if a < b { return a diff --git a/internal/server/rest.go b/internal/server/rest.go index 35a165d..9692b8c 100644 --- a/internal/server/rest.go +++ b/internal/server/rest.go @@ -9,30 +9,26 @@ import ( "encoding/json" "errors" "fmt" - "io" "log/slog" "net/http" "net/url" - "path/filepath" "strconv" "strings" "time" "github.com/GrayCodeAI/yaad/embeddings" "github.com/GrayCodeAI/yaad/engine" - "github.com/GrayCodeAI/yaad/exportimport" - gitwatch "github.com/GrayCodeAI/yaad/git" - "github.com/GrayCodeAI/yaad/graph" - "github.com/GrayCodeAI/yaad/internal/bench" "github.com/GrayCodeAI/yaad/internal/telemetry" - "github.com/GrayCodeAI/yaad/internal/version" - "github.com/GrayCodeAI/yaad/skill" - "github.com/GrayCodeAI/yaad/storage" - "github.com/google/uuid" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" ) +// Note: the REST request handlers were moved verbatim into +// rest_handlers.go (core memory/graph/session/ops handlers) and +// rest_extra.go (export/import, skill, version-history, and advanced +// feature handlers) for readability. This file keeps the server +// lifecycle, middleware, route registration, and shared HTTP helpers. + const ( maxRequestBodySize = 1 << 20 // 1 MB maxResponseSize = 5 * (1 << 20) // 5 MB @@ -318,654 +314,6 @@ func (s *RESTServer) RegisterRoutes(mux *http.ServeMux) { mux.HandleFunc("GET /yaad/watch", s.handleWatchMemories) } -func (s *RESTServer) handleRemember(w http.ResponseWriter, r *http.Request) { - var in engine.RememberInput - if err := decodeJSON(r, &in); err != nil { - var maxBytesErr *http.MaxBytesError - if errors.As(err, &maxBytesErr) { - httpErr(w, err, 413) - } else { - httpErr(w, err, 400) - } - return - } - if in.Type != "" && !engine.IsValidNodeType(in.Type) { - httpErr(w, fmt.Errorf("invalid node type: %q", in.Type), 400) - return - } - node, err := s.eng.Remember(r.Context(), in) - if err != nil { - httpErr(w, err, 500) - return - } - httpJSON(w, node, 201) -} - -func (s *RESTServer) handleRecall(w http.ResponseWriter, r *http.Request) { - var opts engine.RecallOpts - if err := decodeJSON(r, &opts); err != nil { - var maxBytesErr *http.MaxBytesError - if errors.As(err, &maxBytesErr) { - httpErr(w, err, 413) - } else { - httpErr(w, err, 400) - } - return - } - if opts.Depth > maxGraphDepth { - httpErr(w, fmt.Errorf("depth exceeds maximum of %d", maxGraphDepth), 400) - return - } - result, err := s.eng.Recall(r.Context(), opts) - if err != nil { - httpErr(w, err, 500) - return - } - httpJSONCapped(w, result, 200) -} - -func (s *RESTServer) handleContext(w http.ResponseWriter, r *http.Request) { - project := r.URL.Query().Get("project") - result, err := s.eng.Context(r.Context(), project) - if err != nil { - httpErr(w, err, 500) - return - } - httpJSONCapped(w, result, 200) -} - -func (s *RESTServer) handleLink(w http.ResponseWriter, r *http.Request) { - var edge storage.Edge - if err := decodeJSON(r, &edge); err != nil { - var maxBytesErr *http.MaxBytesError - if errors.As(err, &maxBytesErr) { - httpErr(w, err, 413) - } else { - httpErr(w, err, 400) - } - return - } - if edge.Type == "" { - httpErr(w, fmt.Errorf("edge type is required"), 400) - return - } - if !graph.IsValidEdgeType(edge.Type) { - httpErr(w, fmt.Errorf("invalid edge type: %q", edge.Type), 400) - return - } - if edge.ID == "" { - edge.ID = uuid.New().String() - } - if err := s.eng.Graph().AddEdge(r.Context(), &edge); err != nil { - httpErr(w, err, 400) - return - } - httpJSON(w, edge, 201) -} - -// rateLimit returns 429 if the rate limiter rejects the request. -func (s *RESTServer) rateLimit(w http.ResponseWriter) bool { - if s.limiter != nil && !s.limiter.Allow() { - httpErr(w, fmt.Errorf("rate limit exceeded, try again later"), 429) - return false - } - return true -} - -func (s *RESTServer) handleDeleteLink(w http.ResponseWriter, r *http.Request) { - if !s.rateLimit(w) { - return - } - id := r.PathValue("id") - if err := s.eng.Graph().RemoveEdge(r.Context(), id); err != nil { - httpErr(w, err, 404) - return - } - httpJSON(w, map[string]string{"status": "deleted"}, 200) -} - -func (s *RESTServer) handleGetNode(w http.ResponseWriter, r *http.Request) { - start := time.Now() - id := r.PathValue("id") - node, err := s.eng.Store().GetNode(r.Context(), id) - telemetry.MemoryRetrieveDuration.Record(r.Context(), time.Since(start).Seconds()) - if err != nil { - httpErr(w, err, 404) - return - } - neighbors, _ := s.eng.Store().GetNeighbors(r.Context(), id) - httpJSON(w, map[string]any{"node": node, "neighbors": neighbors}, 200) -} - -func (s *RESTServer) handleSubgraph(w http.ResponseWriter, r *http.Request) { - id := r.PathValue("id") - depth := intQuery(r, "depth", 2) - if depth > maxGraphDepth { - httpErr(w, fmt.Errorf("depth exceeds maximum of %d", maxGraphDepth), 400) - return - } - sg, err := s.eng.Graph().ExtractSubgraph(r.Context(), id, depth) - if err != nil { - httpErr(w, err, 500) - return - } - httpJSONCapped(w, sg, 200) -} - -func (s *RESTServer) handleImpact(w http.ResponseWriter, r *http.Request) { - file := r.PathValue("file") - depth := intQuery(r, "depth", 3) - if depth > maxGraphDepth { - httpErr(w, fmt.Errorf("depth exceeds maximum of %d", maxGraphDepth), 400) - return - } - ids, err := s.eng.Graph().Impact(r.Context(), file, depth) - if err != nil { - httpErr(w, err, 500) - return - } - var nodes []*storage.Node - for _, id := range ids { - if n, err := s.eng.Store().GetNode(r.Context(), id); err == nil { - nodes = append(nodes, n) - } - } - httpJSONCapped(w, nodes, 200) -} - -func (s *RESTServer) handleForget(w http.ResponseWriter, r *http.Request) { - if !s.rateLimit(w) { - return - } - id := r.PathValue("id") - if err := s.eng.Forget(r.Context(), id); err != nil { - httpErr(w, err, 404) - return - } - httpJSON(w, map[string]string{"status": "forgotten"}, 200) -} - -func (s *RESTServer) handleUpdateNode(w http.ResponseWriter, r *http.Request) { - id := r.PathValue("id") - node, err := s.eng.Store().GetNode(r.Context(), id) - if err != nil { - httpErr(w, err, 404) - return - } - - var patch struct { - Content *string `json:"content"` - Summary *string `json:"summary"` - Tags *string `json:"tags"` - Key *string `json:"key"` - Pinned *bool `json:"pinned"` - Type *string `json:"type"` - Tier *int `json:"tier"` - } - if err := decodeJSON(r, &patch); err != nil { - httpErr(w, err, 400) - return - } - - // Save version before modifying - if patch.Content != nil && *patch.Content != node.Content { - _ = s.eng.Store().SaveVersion(r.Context(), node.ID, node.Content, "api", "updated via PATCH") - } - - if patch.Content != nil { - node.Content = *patch.Content - node.ContentHash = engine.ContentHash(node.Content, node.Scope, node.Project, node.Type) - } - if patch.Summary != nil { - node.Summary = *patch.Summary - } - if patch.Tags != nil { - node.Tags = *patch.Tags - } - if patch.Key != nil { - node.Key = *patch.Key - } - if patch.Pinned != nil { - node.Pinned = *patch.Pinned - } - if patch.Type != nil { - if !engine.IsValidNodeType(*patch.Type) { - httpErr(w, fmt.Errorf("invalid node type: %q", *patch.Type), 400) - return - } - node.Type = *patch.Type - } - if patch.Tier != nil { - node.Tier = *patch.Tier - } - node.Version++ - if err := s.eng.Store().UpdateNode(r.Context(), node); err != nil { - httpErr(w, err, 500) - return - } - httpJSON(w, node, 200) -} - -func (s *RESTServer) handlePinNode(w http.ResponseWriter, r *http.Request) { - id := r.PathValue("id") - node, err := s.eng.Store().GetNode(r.Context(), id) - if err != nil { - httpErr(w, err, 404) - return - } - node.Pinned = !node.Pinned - if err := s.eng.Store().UpdateNode(r.Context(), node); err != nil { - httpErr(w, err, 500) - return - } - httpJSON(w, map[string]any{"id": node.ID, "pinned": node.Pinned}, 200) -} - -func (s *RESTServer) handleHealth(w http.ResponseWriter, r *http.Request) { - // Actually verify database connectivity with a lightweight query - _, err := s.eng.Store().ListNodes(r.Context(), storage.NodeFilter{}) - if err != nil { - httpJSON(w, map[string]string{"status": "error", "error": err.Error()}, 503) - return - } - httpJSON(w, map[string]string{"status": "ok", "version": version.String()}, 200) -} - -func (s *RESTServer) handleVersion(w http.ResponseWriter, _ *http.Request) { - httpJSON(w, map[string]string{"version": version.String()}, 200) -} - -// handleLiveness responds to Kubernetes liveness probes (/healthz). -// Always returns 200 — if the process is alive, it is live. -func (s *RESTServer) handleLiveness(w http.ResponseWriter, _ *http.Request) { - httpJSON(w, map[string]string{"status": "ok"}, 200) -} - -// handleReadiness responds to Kubernetes readiness probes (/readyz). -// Delegates to the full health check to confirm the database is reachable. -func (s *RESTServer) handleReadiness(w http.ResponseWriter, r *http.Request) { - s.handleHealth(w, r) -} - -func (s *RESTServer) handleStats(w http.ResponseWriter, r *http.Request) { - project := r.URL.Query().Get("project") - st, err := s.eng.Status(r.Context(), project) - if err != nil { - httpErr(w, err, 500) - return - } - httpJSON(w, st, 200) -} - -func (s *RESTServer) handleSessions(w http.ResponseWriter, r *http.Request) { - project := r.URL.Query().Get("project") - limit := intQuery(r, "limit", 10) - sessions, err := s.eng.Store().ListSessions(r.Context(), project, limit) - if err != nil { - httpErr(w, err, 500) - return - } - httpJSON(w, sessions, 200) -} - -func (s *RESTServer) handleSessionStart(w http.ResponseWriter, r *http.Request) { - var body struct { - Project string `json:"project"` - Agent string `json:"agent"` - } - if err := decodeJSON(r, &body); err != nil { - var maxBytesErr *http.MaxBytesError - if errors.As(err, &maxBytesErr) { - httpErr(w, err, 413) - } else { - httpErr(w, err, 400) - } - return - } - sess := &storage.Session{ - ID: uuid.New().String(), - Project: body.Project, - Agent: body.Agent, - StartedAt: time.Now(), - } - if err := s.eng.Store().CreateSession(r.Context(), sess); err != nil { - httpErr(w, err, 500) - return - } - ctxRes, err := s.eng.Context(r.Context(), body.Project) - if err != nil { - slog.Warn("session start: context load failed", "error", err) - } - httpJSON(w, map[string]any{"session": sess, "context": ctxRes}, 201) -} - -func (s *RESTServer) handleSessionEnd(w http.ResponseWriter, r *http.Request) { - var body struct { - ID string `json:"id"` - Summary string `json:"summary"` - } - if err := decodeJSON(r, &body); err != nil { - var maxBytesErr *http.MaxBytesError - if errors.As(err, &maxBytesErr) { - httpErr(w, err, 413) - } else { - httpErr(w, err, 400) - } - return - } - if err := s.eng.Store().EndSession(r.Context(), body.ID, body.Summary); err != nil { - httpErr(w, err, 500) - return - } - httpJSON(w, map[string]string{"status": "ended"}, 200) -} - -func (s *RESTServer) handleStale(w http.ResponseWriter, r *http.Request) { - if s.projectDir == "" { - httpJSON(w, map[string]string{"status": "no project directory configured"}, 200) - return - } - // gitwatch.New is lightweight (path validation + struct creation only). - watcher, err := gitwatch.New(s.eng.Store(), s.eng.Graph(), s.projectDir) - if err != nil { - httpErr(w, err, 500) - return - } - since := time.Now().Add(-7 * 24 * time.Hour) // last 7 days - reports, err := watcher.StalesSince(r.Context(), since) - if err != nil { - httpErr(w, err, 500) - return - } - httpJSON(w, reports, 200) -} - -func (s *RESTServer) handleEmbed(w http.ResponseWriter, r *http.Request) { - var body struct { - NodeID string `json:"node_id"` - } - if err := decodeJSON(r, &body); err != nil { - var maxBytesErr *http.MaxBytesError - if errors.As(err, &maxBytesErr) { - httpErr(w, err, 413) - } else { - httpErr(w, err, 400) - } - return - } - if s.embedder == nil { - httpErr(w, fmt.Errorf("no embedding provider configured"), 503) - return - } - node, err := s.eng.Store().GetNode(r.Context(), body.NodeID) - if err != nil { - httpErr(w, err, 404) - return - } - // Document mode: stored content is embedded as a document (search_document) - // so it pairs correctly with query-mode embeddings at search time. - vec, err := s.embedder.EmbedWithMode(r.Context(), node.Content, embeddings.ModeDocument) - if err != nil { - httpErr(w, err, 500) - return - } - if err := s.eng.Store().SaveEmbedding(r.Context(), node.ID, s.embedder.Name(), vec); err != nil { - httpErr(w, err, 500) - return - } - httpJSON(w, map[string]any{"node_id": node.ID, "dims": len(vec)}, 200) -} - -func (s *RESTServer) handleHybridRecall(w http.ResponseWriter, r *http.Request) { - var opts engine.RecallOpts - if err := decodeJSON(r, &opts); err != nil { - var maxBytesErr *http.MaxBytesError - if errors.As(err, &maxBytesErr) { - httpErr(w, err, 413) - } else { - httpErr(w, err, 400) - } - return - } - hs := engine.NewHybridSearch(s.eng.Store(), s.eng.Graph(), s.embedder) - scored, err := hs.Search(r.Context(), opts.Query, opts) - if err != nil { - httpErr(w, err, 500) - return - } - reranked := engine.Rerank(r.Context(), scored, s.eng.Store()) - httpJSONCapped(w, reranked, 200) -} - -func (s *RESTServer) handleProactive(w http.ResponseWriter, r *http.Request) { - project := r.URL.Query().Get("project") - hs := engine.NewHybridSearch(s.eng.Store(), s.eng.Graph(), s.embedder) - pc := engine.NewProactiveContext(s.eng, hs) - nodes, err := pc.Predict(r.Context(), project, 2000) - if err != nil { - httpErr(w, err, 500) - return - } - httpJSON(w, map[string]any{ - "nodes": nodes, - "context": engine.FormatContext(nodes), - }, 200) -} - -func (s *RESTServer) handleFeedback(w http.ResponseWriter, r *http.Request) { - var body struct { - ID string `json:"id"` - Action engine.FeedbackAction `json:"action"` - NewContent string `json:"new_content"` - } - if err := decodeJSON(r, &body); err != nil { - var maxBytesErr *http.MaxBytesError - if errors.As(err, &maxBytesErr) { - httpErr(w, err, 413) - } else { - httpErr(w, err, 400) - } - return - } - if err := s.eng.Feedback(r.Context(), body.ID, body.Action, body.NewContent); err != nil { - httpErr(w, err, 400) - return - } - httpJSON(w, map[string]string{"status": "ok"}, 200) -} - -func (s *RESTServer) handleDecay(w http.ResponseWriter, r *http.Request) { - if !s.rateLimit(w) { - return - } - if err := engine.RunDecay(r.Context(), s.eng.Store(), s.eng.DecayConfig); err != nil { - httpErr(w, err, 500) - return - } - httpJSON(w, map[string]string{"status": "decay applied"}, 200) -} - -func (s *RESTServer) handleGC(w http.ResponseWriter, r *http.Request) { - if !s.rateLimit(w) { - return - } - n, err := engine.GarbageCollect(r.Context(), s.eng.Store(), s.eng.DecayConfig) - if err != nil { - httpErr(w, err, 500) - return - } - httpJSON(w, map[string]int{"removed": n}, 200) -} - -func (s *RESTServer) handleReplay(w http.ResponseWriter, r *http.Request) { - sessionID := r.PathValue("session_id") - events, err := s.eng.Store().GetReplayEvents(r.Context(), sessionID) - if err != nil { - httpErr(w, err, 500) - return - } - httpJSON(w, events, 200) -} - -func (s *RESTServer) handleExportJSON(w http.ResponseWriter, r *http.Request) { - project := r.URL.Query().Get("project") - data, err := exportimport.ExportJSON(r.Context(), s.eng.Store(), project) - if err != nil { - httpErr(w, err, 500) - return - } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(200) - _, _ = w.Write(data) -} - -func (s *RESTServer) handleExportMarkdown(w http.ResponseWriter, r *http.Request) { - project := r.URL.Query().Get("project") - md, err := exportimport.ExportMarkdown(r.Context(), s.eng.Store(), project) - if err != nil { - httpErr(w, err, 500) - return - } - w.Header().Set("Content-Type", "text/markdown") - w.WriteHeader(200) - _, _ = fmt.Fprint(w, md) -} - -func (s *RESTServer) handleExportObsidian(w http.ResponseWriter, r *http.Request) { - var body struct { - Project string `json:"project"` - VaultDir string `json:"vault_dir"` - } - if err := decodeJSON(r, &body); err != nil { - var maxBytesErr *http.MaxBytesError - if errors.As(err, &maxBytesErr) { - httpErr(w, err, 413) - } else { - httpErr(w, err, 400) - } - return - } - if body.VaultDir == "" { - httpErr(w, fmt.Errorf("vault_dir is required"), 400) - return - } - // Prevent path traversal — vault_dir must be absolute and not contain .. - cleanPath := filepath.Clean(body.VaultDir) - if cleanPath != body.VaultDir || !filepath.IsAbs(cleanPath) { - httpErr(w, fmt.Errorf("vault_dir must be a clean absolute path"), 400) - return - } - // Restrict to project directory if one is configured - if s.projectDir != "" { - projClean := filepath.Clean(s.projectDir) - if !strings.HasPrefix(cleanPath, projClean+string(filepath.Separator)) && cleanPath != projClean { - httpErr(w, fmt.Errorf("vault_dir must be within the project directory"), 400) - return - } - } - n, err := exportimport.ExportObsidian(r.Context(), s.eng.Store(), body.Project, cleanPath) - if err != nil { - httpErr(w, err, 500) - return - } - httpJSON(w, map[string]int{"written": n}, 200) -} - -func (s *RESTServer) handleImportJSON(w http.ResponseWriter, r *http.Request) { - data, err := io.ReadAll(io.LimitReader(r.Body, maxRequestBodySize)) - if err != nil { - var maxBytesErr *http.MaxBytesError - if errors.As(err, &maxBytesErr) { - httpErr(w, fmt.Errorf("request body exceeds %d bytes", maxRequestBodySize), 413) - return - } - httpErr(w, err, 400) - return - } - nodes, edges, err := exportimport.ImportJSON(r.Context(), s.eng.Store(), data) - if err != nil { - httpErr(w, err, 400) - return - } - httpJSON(w, map[string]int{"nodes": nodes, "edges": edges}, 200) -} - -func (s *RESTServer) handleSkillStore(w http.ResponseWriter, r *http.Request) { - var sk skill.Skill - if err := decodeJSON(r, &sk); err != nil { - var maxBytesErr *http.MaxBytesError - if errors.As(err, &maxBytesErr) { - httpErr(w, err, 413) - } else { - httpErr(w, err, 400) - } - return - } - project := r.URL.Query().Get("project") - node, err := skill.Store(r.Context(), s.eng, &sk, project) - if err != nil { - httpErr(w, err, 500) - return - } - httpJSON(w, node, 201) -} - -func (s *RESTServer) handleSkillList(w http.ResponseWriter, r *http.Request) { - project := r.URL.Query().Get("project") - skills, err := skill.ListSkills(r.Context(), s.eng.Store(), project) - if err != nil { - httpErr(w, err, 500) - return - } - httpJSON(w, skills, 200) -} - -func (s *RESTServer) handleSkillGet(w http.ResponseWriter, r *http.Request) { - name := r.PathValue("name") - project := r.URL.Query().Get("project") - sk, err := skill.Load(r.Context(), s.eng.Store(), name, project) - if err != nil { - httpErr(w, err, 404) - return - } - httpJSON(w, map[string]string{"replay": skill.Replay(sk)}, 200) -} - -func (s *RESTServer) handleBench(w http.ResponseWriter, r *http.Request) { - result := bench.Run(r.Context(), s.eng, bench.DefaultQAs(), 2, 10) - httpJSON(w, map[string]string{"report": result.String()}, 200) -} - -func (s *RESTServer) handleCompact(w http.ResponseWriter, r *http.Request) { - project := r.URL.Query().Get("project") - n, err := s.eng.Compact(r.Context(), project) - if err != nil { - httpErr(w, err, 500) - return - } - httpJSON(w, map[string]int{"compacted": n}, 200) -} - -func (s *RESTServer) handleMentalModel(w http.ResponseWriter, r *http.Request) { - project := r.URL.Query().Get("project") - model, err := s.eng.MentalModel(r.Context(), project) - if err != nil { - httpErr(w, err, 500) - return - } - httpJSON(w, map[string]any{"model": model, "formatted": model.Format()}, 200) -} - -func (s *RESTServer) handleProfile(w http.ResponseWriter, r *http.Request) { - project := r.URL.Query().Get("project") - p, err := s.eng.Profile(r.Context(), project) - if err != nil { - httpErr(w, err, 500) - return - } - httpJSON(w, map[string]any{"profile": p, "formatted": p.Format()}, 200) -} - // --- helpers --- func httpJSON(w http.ResponseWriter, v any, code int) { @@ -1026,345 +374,3 @@ func intQuery(r *http.Request, key string, def int) int { } return n } - -// --- Version history handlers --- - -func (s *RESTServer) handleVersions(w http.ResponseWriter, r *http.Request) { - id := r.PathValue("id") - history, err := s.eng.GetNodeHistory(r.Context(), id) - if err != nil { - httpErr(w, err, 404) - return - } - httpJSON(w, map[string]any{"node_id": id, "versions": history}, 200) -} - -func (s *RESTServer) handleRollback(w http.ResponseWriter, r *http.Request) { - if !s.rateLimit(w) { - return - } - id := r.PathValue("id") - var body struct { - Version int `json:"version"` - } - if err := decodeJSON(r, &body); err != nil { - httpErr(w, err, 400) - return - } - if body.Version <= 0 { - httpErr(w, fmt.Errorf("version must be positive"), 400) - return - } - if err := s.eng.Rollback(r.Context(), id, body.Version); err != nil { - httpErr(w, err, 400) - return - } - httpJSON(w, map[string]any{"status": "rolled_back", "node_id": id, "version": body.Version}, 200) -} - -func (s *RESTServer) handleDiff(w http.ResponseWriter, r *http.Request) { - id := r.PathValue("id") - v1 := intQuery(r, "v1", 0) - v2 := intQuery(r, "v2", 0) - if v1 <= 0 || v2 <= 0 { - httpErr(w, fmt.Errorf("v1 and v2 query params required (positive integers)"), 400) - return - } - c1, c2, err := s.eng.DiffVersions(r.Context(), id, v1, v2) - if err != nil { - httpErr(w, err, 400) - return - } - httpJSON(w, map[string]any{ - "node_id": id, - "v1": map[string]any{"version": v1, "content": c1}, - "v2": map[string]any{"version": v2, "content": c2}, - }, 200) -} - -// --- Confidence-scored recall handler --- - -func (s *RESTServer) handleRecallConfidence(w http.ResponseWriter, r *http.Request) { - var opts engine.RecallOpts - if err := decodeJSON(r, &opts); err != nil { - var maxBytesErr *http.MaxBytesError - if errors.As(err, &maxBytesErr) { - httpErr(w, err, 413) - } else { - httpErr(w, err, 400) - } - return - } - if opts.Depth > maxGraphDepth { - httpErr(w, fmt.Errorf("depth exceeds maximum of %d", maxGraphDepth), 400) - return - } - result, err := s.eng.RecallWithConfidence(r.Context(), opts) - if err != nil { - httpErr(w, err, 500) - return - } - httpJSONCapped(w, result, 200) -} - -// --- Session compression handler --- - -func (s *RESTServer) handleSessionCompress(w http.ResponseWriter, r *http.Request) { - var body struct { - SessionID string `json:"session_id"` - } - if err := decodeJSON(r, &body); err != nil { - httpErr(w, err, 400) - return - } - if body.SessionID == "" { - httpErr(w, fmt.Errorf("session_id is required"), 400) - return - } - n, err := s.eng.CompressSessionEvents(r.Context(), body.SessionID) - if err != nil { - httpErr(w, err, 500) - return - } - httpJSON(w, map[string]any{"compressed": n}, 200) -} - -// --- Agent file bridge handlers --- - -func (s *RESTServer) handleBridgeImport(w http.ResponseWriter, r *http.Request) { - if s.projectDir == "" { - httpErr(w, fmt.Errorf("no project directory configured"), 400) - return - } - bridge := engine.NewAgentFileBridge(s.projectDir) - rules, err := bridge.Import() - if err != nil { - httpErr(w, err, 500) - return - } - httpJSON(w, map[string]any{"rules": rules, "count": len(rules)}, 200) -} - -func (s *RESTServer) handleBridgeExport(w http.ResponseWriter, r *http.Request) { - if s.projectDir == "" { - httpErr(w, fmt.Errorf("no project directory configured"), 400) - return - } - var body struct { - Rules []engine.AgentRule `json:"rules"` - FileType engine.AgentFileType `json:"file_type"` - } - if err := decodeJSON(r, &body); err != nil { - httpErr(w, err, 400) - return - } - if body.FileType == "" { - body.FileType = engine.FileClaudeMD - } - bridge := engine.NewAgentFileBridge(s.projectDir) - if err := bridge.Export(body.Rules, body.FileType); err != nil { - httpErr(w, err, 500) - return - } - httpJSON(w, map[string]string{"status": "exported"}, 200) -} - -func (s *RESTServer) handleBridgeSync(w http.ResponseWriter, r *http.Request) { - if s.projectDir == "" { - httpErr(w, fmt.Errorf("no project directory configured"), 400) - return - } - var body struct { - Conventions []string `json:"conventions"` - } - if err := decodeJSON(r, &body); err != nil { - httpErr(w, err, 400) - return - } - bridge := engine.NewAgentFileBridge(s.projectDir) - if err := bridge.Sync(body.Conventions); err != nil { - httpErr(w, err, 500) - return - } - httpJSON(w, map[string]string{"status": "synced"}, 200) -} - -// --- Advanced feature handlers --- - -func (s *RESTServer) handleCommunities(w http.ResponseWriter, r *http.Request) { - cd := engine.NewCommunityDetector(s.eng.Store()) - communities, err := cd.Detect(r.Context(), 10) - if err != nil { - httpErr(w, err, 500) - return - } - communities = cd.Summarize(r.Context(), communities) - httpJSON(w, communities, 200) -} - -func (s *RESTServer) handleHierarchy(w http.ResponseWriter, r *http.Request) { - hm := engine.NewHierarchicalMemory(s.eng.Store()) - if err := hm.Build(r.Context()); err != nil { - httpErr(w, err, 500) - return - } - - var req struct { - Query string `json:"query"` - Level int `json:"level"` - } - req.Level = -1 - _ = decodeJSON(r, &req) - - if req.Level < 0 && req.Query != "" { - req.Level = hm.RetrieveAdaptive(req.Query) - } - if req.Level < 0 { - req.Level = 1 - } - - if req.Query != "" { - clusters := hm.RetrieveAtLevel(req.Query, req.Level) - httpJSON(w, map[string]interface{}{ - "level": req.Level, "query": req.Query, "clusters": clusters, - }, 200) - return - } - - httpJSON(w, map[string]interface{}{ - "level": req.Level, "content": hm.FormatLevel(req.Level), - }, 200) -} - -func (s *RESTServer) handleSparsify(w http.ResponseWriter, r *http.Request) { - if !s.rateLimit(w) { - return - } - sp := engine.NewSparsifier(s.eng.Store()) - result, err := sp.Run(r.Context()) - if err != nil { - httpErr(w, err, 500) - return - } - httpJSON(w, map[string]interface{}{ - "merged": result.Merged, "compressed": result.Compressed, - "pruned": result.Pruned, "total": result.Merged + result.Compressed + result.Pruned, - }, 200) -} - -func (s *RESTServer) handleVerify(w http.ResponseWriter, r *http.Request) { - yaadDir := filepath.Join(s.projectDir, ".yaad") - if s.projectDir == "" { - yaadDir = ".yaad" - } - mi, err := engine.NewMemoryIntegrity(yaadDir) - if err != nil { - httpErr(w, err, 500) - return - } - nodes, err := s.eng.Store().ListNodes(r.Context(), storage.NodeFilter{Limit: 10000}) - if err != nil { - httpErr(w, err, 500) - return - } - storedSigs, err := s.eng.Store().GetAllSignatures(r.Context()) - if err != nil { - httpErr(w, err, 500) - return - } - tampered := mi.VerifyBatch(nodes, storedSigs) - status := "ok" - if len(tampered) > 0 { - status = "tampered" - } - httpJSON(w, map[string]interface{}{ - "status": status, "nodes_verified": len(nodes), "tampered": len(tampered), - }, 200) -} - -func (s *RESTServer) handleEntities(w http.ResponseWriter, r *http.Request) { - // Build entity index and return stats - idx := engine.NewEntityIndex() - nodes, err := s.eng.Store().ListNodes(r.Context(), storage.NodeFilter{Limit: 5000}) - if err != nil { - httpErr(w, err, 500) - return - } - for _, n := range nodes { - idx.IndexNode(n) - } - httpJSON(w, map[string]interface{}{ - "unique_entities": idx.Size(), - "nodes_indexed": len(nodes), - }, 200) -} - -func (s *RESTServer) handleDoctor(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - store := s.eng.Store() - - // Graph health — use SQL aggregation instead of loading all nodes - types, _, _ := store.NodeStats(ctx) - ds, _ := store.DoctorStats(ctx) - totalNodes := ds.TotalNodes - orphanCount := ds.Orphans - lowConfCount := ds.LowConfidence - pinnedCount := ds.Pinned - staleCount := 0 - totalEdges, _ := store.CountAllEdges(ctx) - coverage := 0 - if types["convention"] > 0 { - coverage += 20 - } - if types["decision"] > 0 { - coverage += 20 - } - if types["spec"] > 0 { - coverage += 20 - } - if types["task"] > 0 { - coverage += 15 - } - if types["bug"] > 0 { - coverage += 15 - } - if types["preference"] > 0 { - coverage += 10 - } - - // Recommendations - var recs []string - if orphanCount > totalNodes/4 { - recs = append(recs, "High orphan count — run 'yaad sparsify' to clean up disconnected memories") - } - if lowConfCount > totalNodes/3 { - recs = append(recs, "Many low-confidence memories — run 'yaad decay' and 'yaad gc'") - } - if types["convention"] == 0 { - recs = append(recs, "No conventions stored — teach your agent coding rules") - } - if types["decision"] == 0 { - recs = append(recs, "No decisions stored — record architecture choices") - } - if types["spec"] == 0 { - recs = append(recs, "No specs stored — document subsystem designs") - } - if pinnedCount == 0 { - recs = append(recs, "No pinned memories — pin critical conventions with 'yaad pin '") - } - if len(recs) == 0 { - recs = append(recs, "Memory graph looks healthy!") - } - - httpJSON(w, map[string]interface{}{ - "nodes": totalNodes, - "edges": totalEdges, - "orphans": orphanCount, - "low_confidence": lowConfCount, - "pinned": pinnedCount, - "stale": staleCount, - "coverage_score": coverage, - "types": types, - "recommendations": recs, - }, 200) -} diff --git a/internal/server/rest_extra.go b/internal/server/rest_extra.go new file mode 100644 index 0000000..762c3ed --- /dev/null +++ b/internal/server/rest_extra.go @@ -0,0 +1,522 @@ +// This file is part of package server. It holds the export/import, skill, +// version-history, confidence, compression, bridge, and advanced-feature +// REST handlers split verbatim out of rest.go for readability; behavior +// is unchanged. + +package server + +import ( + "errors" + "fmt" + "io" + "net/http" + "path/filepath" + "strings" + + "github.com/GrayCodeAI/yaad/engine" + "github.com/GrayCodeAI/yaad/exportimport" + "github.com/GrayCodeAI/yaad/internal/bench" + "github.com/GrayCodeAI/yaad/skill" + "github.com/GrayCodeAI/yaad/storage" +) + +func (s *RESTServer) handleExportJSON(w http.ResponseWriter, r *http.Request) { + project := r.URL.Query().Get("project") + data, err := exportimport.ExportJSON(r.Context(), s.eng.Store(), project) + if err != nil { + httpErr(w, err, 500) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + _, _ = w.Write(data) +} + +func (s *RESTServer) handleExportMarkdown(w http.ResponseWriter, r *http.Request) { + project := r.URL.Query().Get("project") + md, err := exportimport.ExportMarkdown(r.Context(), s.eng.Store(), project) + if err != nil { + httpErr(w, err, 500) + return + } + w.Header().Set("Content-Type", "text/markdown") + w.WriteHeader(200) + _, _ = fmt.Fprint(w, md) +} + +func (s *RESTServer) handleExportObsidian(w http.ResponseWriter, r *http.Request) { + var body struct { + Project string `json:"project"` + VaultDir string `json:"vault_dir"` + } + if err := decodeJSON(r, &body); err != nil { + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + httpErr(w, err, 413) + } else { + httpErr(w, err, 400) + } + return + } + if body.VaultDir == "" { + httpErr(w, fmt.Errorf("vault_dir is required"), 400) + return + } + // Prevent path traversal — vault_dir must be absolute and not contain .. + cleanPath := filepath.Clean(body.VaultDir) + if cleanPath != body.VaultDir || !filepath.IsAbs(cleanPath) { + httpErr(w, fmt.Errorf("vault_dir must be a clean absolute path"), 400) + return + } + // Restrict to project directory if one is configured + if s.projectDir != "" { + projClean := filepath.Clean(s.projectDir) + if !strings.HasPrefix(cleanPath, projClean+string(filepath.Separator)) && cleanPath != projClean { + httpErr(w, fmt.Errorf("vault_dir must be within the project directory"), 400) + return + } + } + n, err := exportimport.ExportObsidian(r.Context(), s.eng.Store(), body.Project, cleanPath) + if err != nil { + httpErr(w, err, 500) + return + } + httpJSON(w, map[string]int{"written": n}, 200) +} + +func (s *RESTServer) handleImportJSON(w http.ResponseWriter, r *http.Request) { + data, err := io.ReadAll(io.LimitReader(r.Body, maxRequestBodySize)) + if err != nil { + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + httpErr(w, fmt.Errorf("request body exceeds %d bytes", maxRequestBodySize), 413) + return + } + httpErr(w, err, 400) + return + } + nodes, edges, err := exportimport.ImportJSON(r.Context(), s.eng.Store(), data) + if err != nil { + httpErr(w, err, 400) + return + } + httpJSON(w, map[string]int{"nodes": nodes, "edges": edges}, 200) +} + +func (s *RESTServer) handleSkillStore(w http.ResponseWriter, r *http.Request) { + var sk skill.Skill + if err := decodeJSON(r, &sk); err != nil { + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + httpErr(w, err, 413) + } else { + httpErr(w, err, 400) + } + return + } + project := r.URL.Query().Get("project") + node, err := skill.Store(r.Context(), s.eng, &sk, project) + if err != nil { + httpErr(w, err, 500) + return + } + httpJSON(w, node, 201) +} + +func (s *RESTServer) handleSkillList(w http.ResponseWriter, r *http.Request) { + project := r.URL.Query().Get("project") + skills, err := skill.ListSkills(r.Context(), s.eng.Store(), project) + if err != nil { + httpErr(w, err, 500) + return + } + httpJSON(w, skills, 200) +} + +func (s *RESTServer) handleSkillGet(w http.ResponseWriter, r *http.Request) { + name := r.PathValue("name") + project := r.URL.Query().Get("project") + sk, err := skill.Load(r.Context(), s.eng.Store(), name, project) + if err != nil { + httpErr(w, err, 404) + return + } + httpJSON(w, map[string]string{"replay": skill.Replay(sk)}, 200) +} + +func (s *RESTServer) handleBench(w http.ResponseWriter, r *http.Request) { + result := bench.Run(r.Context(), s.eng, bench.DefaultQAs(), 2, 10) + httpJSON(w, map[string]string{"report": result.String()}, 200) +} + +func (s *RESTServer) handleCompact(w http.ResponseWriter, r *http.Request) { + project := r.URL.Query().Get("project") + n, err := s.eng.Compact(r.Context(), project) + if err != nil { + httpErr(w, err, 500) + return + } + httpJSON(w, map[string]int{"compacted": n}, 200) +} + +func (s *RESTServer) handleMentalModel(w http.ResponseWriter, r *http.Request) { + project := r.URL.Query().Get("project") + model, err := s.eng.MentalModel(r.Context(), project) + if err != nil { + httpErr(w, err, 500) + return + } + httpJSON(w, map[string]any{"model": model, "formatted": model.Format()}, 200) +} + +func (s *RESTServer) handleProfile(w http.ResponseWriter, r *http.Request) { + project := r.URL.Query().Get("project") + p, err := s.eng.Profile(r.Context(), project) + if err != nil { + httpErr(w, err, 500) + return + } + httpJSON(w, map[string]any{"profile": p, "formatted": p.Format()}, 200) +} + +// --- Version history handlers --- + +func (s *RESTServer) handleVersions(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + history, err := s.eng.GetNodeHistory(r.Context(), id) + if err != nil { + httpErr(w, err, 404) + return + } + httpJSON(w, map[string]any{"node_id": id, "versions": history}, 200) +} + +func (s *RESTServer) handleRollback(w http.ResponseWriter, r *http.Request) { + if !s.rateLimit(w) { + return + } + id := r.PathValue("id") + var body struct { + Version int `json:"version"` + } + if err := decodeJSON(r, &body); err != nil { + httpErr(w, err, 400) + return + } + if body.Version <= 0 { + httpErr(w, fmt.Errorf("version must be positive"), 400) + return + } + if err := s.eng.Rollback(r.Context(), id, body.Version); err != nil { + httpErr(w, err, 400) + return + } + httpJSON(w, map[string]any{"status": "rolled_back", "node_id": id, "version": body.Version}, 200) +} + +func (s *RESTServer) handleDiff(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + v1 := intQuery(r, "v1", 0) + v2 := intQuery(r, "v2", 0) + if v1 <= 0 || v2 <= 0 { + httpErr(w, fmt.Errorf("v1 and v2 query params required (positive integers)"), 400) + return + } + c1, c2, err := s.eng.DiffVersions(r.Context(), id, v1, v2) + if err != nil { + httpErr(w, err, 400) + return + } + httpJSON(w, map[string]any{ + "node_id": id, + "v1": map[string]any{"version": v1, "content": c1}, + "v2": map[string]any{"version": v2, "content": c2}, + }, 200) +} + +// --- Confidence-scored recall handler --- + +func (s *RESTServer) handleRecallConfidence(w http.ResponseWriter, r *http.Request) { + var opts engine.RecallOpts + if err := decodeJSON(r, &opts); err != nil { + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + httpErr(w, err, 413) + } else { + httpErr(w, err, 400) + } + return + } + if opts.Depth > maxGraphDepth { + httpErr(w, fmt.Errorf("depth exceeds maximum of %d", maxGraphDepth), 400) + return + } + result, err := s.eng.RecallWithConfidence(r.Context(), opts) + if err != nil { + httpErr(w, err, 500) + return + } + httpJSONCapped(w, result, 200) +} + +// --- Session compression handler --- + +func (s *RESTServer) handleSessionCompress(w http.ResponseWriter, r *http.Request) { + var body struct { + SessionID string `json:"session_id"` + } + if err := decodeJSON(r, &body); err != nil { + httpErr(w, err, 400) + return + } + if body.SessionID == "" { + httpErr(w, fmt.Errorf("session_id is required"), 400) + return + } + n, err := s.eng.CompressSessionEvents(r.Context(), body.SessionID) + if err != nil { + httpErr(w, err, 500) + return + } + httpJSON(w, map[string]any{"compressed": n}, 200) +} + +// --- Agent file bridge handlers --- + +func (s *RESTServer) handleBridgeImport(w http.ResponseWriter, r *http.Request) { + if s.projectDir == "" { + httpErr(w, fmt.Errorf("no project directory configured"), 400) + return + } + bridge := engine.NewAgentFileBridge(s.projectDir) + rules, err := bridge.Import() + if err != nil { + httpErr(w, err, 500) + return + } + httpJSON(w, map[string]any{"rules": rules, "count": len(rules)}, 200) +} + +func (s *RESTServer) handleBridgeExport(w http.ResponseWriter, r *http.Request) { + if s.projectDir == "" { + httpErr(w, fmt.Errorf("no project directory configured"), 400) + return + } + var body struct { + Rules []engine.AgentRule `json:"rules"` + FileType engine.AgentFileType `json:"file_type"` + } + if err := decodeJSON(r, &body); err != nil { + httpErr(w, err, 400) + return + } + if body.FileType == "" { + body.FileType = engine.FileClaudeMD + } + bridge := engine.NewAgentFileBridge(s.projectDir) + if err := bridge.Export(body.Rules, body.FileType); err != nil { + httpErr(w, err, 500) + return + } + httpJSON(w, map[string]string{"status": "exported"}, 200) +} + +func (s *RESTServer) handleBridgeSync(w http.ResponseWriter, r *http.Request) { + if s.projectDir == "" { + httpErr(w, fmt.Errorf("no project directory configured"), 400) + return + } + var body struct { + Conventions []string `json:"conventions"` + } + if err := decodeJSON(r, &body); err != nil { + httpErr(w, err, 400) + return + } + bridge := engine.NewAgentFileBridge(s.projectDir) + if err := bridge.Sync(body.Conventions); err != nil { + httpErr(w, err, 500) + return + } + httpJSON(w, map[string]string{"status": "synced"}, 200) +} + +// --- Advanced feature handlers --- + +func (s *RESTServer) handleCommunities(w http.ResponseWriter, r *http.Request) { + cd := engine.NewCommunityDetector(s.eng.Store()) + communities, err := cd.Detect(r.Context(), 10) + if err != nil { + httpErr(w, err, 500) + return + } + communities = cd.Summarize(r.Context(), communities) + httpJSON(w, communities, 200) +} + +func (s *RESTServer) handleHierarchy(w http.ResponseWriter, r *http.Request) { + hm := engine.NewHierarchicalMemory(s.eng.Store()) + if err := hm.Build(r.Context()); err != nil { + httpErr(w, err, 500) + return + } + + var req struct { + Query string `json:"query"` + Level int `json:"level"` + } + req.Level = -1 + _ = decodeJSON(r, &req) + + if req.Level < 0 && req.Query != "" { + req.Level = hm.RetrieveAdaptive(req.Query) + } + if req.Level < 0 { + req.Level = 1 + } + + if req.Query != "" { + clusters := hm.RetrieveAtLevel(req.Query, req.Level) + httpJSON(w, map[string]interface{}{ + "level": req.Level, "query": req.Query, "clusters": clusters, + }, 200) + return + } + + httpJSON(w, map[string]interface{}{ + "level": req.Level, "content": hm.FormatLevel(req.Level), + }, 200) +} + +func (s *RESTServer) handleSparsify(w http.ResponseWriter, r *http.Request) { + if !s.rateLimit(w) { + return + } + sp := engine.NewSparsifier(s.eng.Store()) + result, err := sp.Run(r.Context()) + if err != nil { + httpErr(w, err, 500) + return + } + httpJSON(w, map[string]interface{}{ + "merged": result.Merged, "compressed": result.Compressed, + "pruned": result.Pruned, "total": result.Merged + result.Compressed + result.Pruned, + }, 200) +} + +func (s *RESTServer) handleVerify(w http.ResponseWriter, r *http.Request) { + yaadDir := filepath.Join(s.projectDir, ".yaad") + if s.projectDir == "" { + yaadDir = ".yaad" + } + mi, err := engine.NewMemoryIntegrity(yaadDir) + if err != nil { + httpErr(w, err, 500) + return + } + nodes, err := s.eng.Store().ListNodes(r.Context(), storage.NodeFilter{Limit: 10000}) + if err != nil { + httpErr(w, err, 500) + return + } + storedSigs, err := s.eng.Store().GetAllSignatures(r.Context()) + if err != nil { + httpErr(w, err, 500) + return + } + tampered := mi.VerifyBatch(nodes, storedSigs) + status := "ok" + if len(tampered) > 0 { + status = "tampered" + } + httpJSON(w, map[string]interface{}{ + "status": status, "nodes_verified": len(nodes), "tampered": len(tampered), + }, 200) +} + +func (s *RESTServer) handleEntities(w http.ResponseWriter, r *http.Request) { + // Build entity index and return stats + idx := engine.NewEntityIndex() + nodes, err := s.eng.Store().ListNodes(r.Context(), storage.NodeFilter{Limit: 5000}) + if err != nil { + httpErr(w, err, 500) + return + } + for _, n := range nodes { + idx.IndexNode(n) + } + httpJSON(w, map[string]interface{}{ + "unique_entities": idx.Size(), + "nodes_indexed": len(nodes), + }, 200) +} + +func (s *RESTServer) handleDoctor(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + store := s.eng.Store() + + // Graph health — use SQL aggregation instead of loading all nodes + types, _, _ := store.NodeStats(ctx) + ds, _ := store.DoctorStats(ctx) + totalNodes := ds.TotalNodes + orphanCount := ds.Orphans + lowConfCount := ds.LowConfidence + pinnedCount := ds.Pinned + staleCount := 0 + totalEdges, _ := store.CountAllEdges(ctx) + coverage := 0 + if types["convention"] > 0 { + coverage += 20 + } + if types["decision"] > 0 { + coverage += 20 + } + if types["spec"] > 0 { + coverage += 20 + } + if types["task"] > 0 { + coverage += 15 + } + if types["bug"] > 0 { + coverage += 15 + } + if types["preference"] > 0 { + coverage += 10 + } + + // Recommendations + var recs []string + if orphanCount > totalNodes/4 { + recs = append(recs, "High orphan count — run 'yaad sparsify' to clean up disconnected memories") + } + if lowConfCount > totalNodes/3 { + recs = append(recs, "Many low-confidence memories — run 'yaad decay' and 'yaad gc'") + } + if types["convention"] == 0 { + recs = append(recs, "No conventions stored — teach your agent coding rules") + } + if types["decision"] == 0 { + recs = append(recs, "No decisions stored — record architecture choices") + } + if types["spec"] == 0 { + recs = append(recs, "No specs stored — document subsystem designs") + } + if pinnedCount == 0 { + recs = append(recs, "No pinned memories — pin critical conventions with 'yaad pin '") + } + if len(recs) == 0 { + recs = append(recs, "Memory graph looks healthy!") + } + + httpJSON(w, map[string]interface{}{ + "nodes": totalNodes, + "edges": totalEdges, + "orphans": orphanCount, + "low_confidence": lowConfCount, + "pinned": pinnedCount, + "stale": staleCount, + "coverage_score": coverage, + "types": types, + "recommendations": recs, + }, 200) +} diff --git a/internal/server/rest_handlers.go b/internal/server/rest_handlers.go new file mode 100644 index 0000000..a4dd4e8 --- /dev/null +++ b/internal/server/rest_handlers.go @@ -0,0 +1,511 @@ +// This file is part of package server. It holds the core REST request +// handlers (memory, graph, node, session, and maintenance endpoints) +// split verbatim out of rest.go for readability; behavior is unchanged. + +package server + +import ( + "errors" + "fmt" + "log/slog" + "net/http" + "time" + + "github.com/GrayCodeAI/yaad/embeddings" + "github.com/GrayCodeAI/yaad/engine" + gitwatch "github.com/GrayCodeAI/yaad/git" + "github.com/GrayCodeAI/yaad/graph" + "github.com/GrayCodeAI/yaad/internal/telemetry" + "github.com/GrayCodeAI/yaad/internal/version" + "github.com/GrayCodeAI/yaad/storage" + "github.com/google/uuid" +) + +func (s *RESTServer) handleRemember(w http.ResponseWriter, r *http.Request) { + var in engine.RememberInput + if err := decodeJSON(r, &in); err != nil { + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + httpErr(w, err, 413) + } else { + httpErr(w, err, 400) + } + return + } + if in.Type != "" && !engine.IsValidNodeType(in.Type) { + httpErr(w, fmt.Errorf("invalid node type: %q", in.Type), 400) + return + } + node, err := s.eng.Remember(r.Context(), in) + if err != nil { + httpErr(w, err, 500) + return + } + httpJSON(w, node, 201) +} + +func (s *RESTServer) handleRecall(w http.ResponseWriter, r *http.Request) { + var opts engine.RecallOpts + if err := decodeJSON(r, &opts); err != nil { + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + httpErr(w, err, 413) + } else { + httpErr(w, err, 400) + } + return + } + if opts.Depth > maxGraphDepth { + httpErr(w, fmt.Errorf("depth exceeds maximum of %d", maxGraphDepth), 400) + return + } + result, err := s.eng.Recall(r.Context(), opts) + if err != nil { + httpErr(w, err, 500) + return + } + httpJSONCapped(w, result, 200) +} + +func (s *RESTServer) handleContext(w http.ResponseWriter, r *http.Request) { + project := r.URL.Query().Get("project") + result, err := s.eng.Context(r.Context(), project) + if err != nil { + httpErr(w, err, 500) + return + } + httpJSONCapped(w, result, 200) +} + +func (s *RESTServer) handleLink(w http.ResponseWriter, r *http.Request) { + var edge storage.Edge + if err := decodeJSON(r, &edge); err != nil { + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + httpErr(w, err, 413) + } else { + httpErr(w, err, 400) + } + return + } + if edge.Type == "" { + httpErr(w, fmt.Errorf("edge type is required"), 400) + return + } + if !graph.IsValidEdgeType(edge.Type) { + httpErr(w, fmt.Errorf("invalid edge type: %q", edge.Type), 400) + return + } + if edge.ID == "" { + edge.ID = uuid.New().String() + } + if err := s.eng.Graph().AddEdge(r.Context(), &edge); err != nil { + httpErr(w, err, 400) + return + } + httpJSON(w, edge, 201) +} + +// rateLimit returns 429 if the rate limiter rejects the request. +func (s *RESTServer) rateLimit(w http.ResponseWriter) bool { + if s.limiter != nil && !s.limiter.Allow() { + httpErr(w, fmt.Errorf("rate limit exceeded, try again later"), 429) + return false + } + return true +} + +func (s *RESTServer) handleDeleteLink(w http.ResponseWriter, r *http.Request) { + if !s.rateLimit(w) { + return + } + id := r.PathValue("id") + if err := s.eng.Graph().RemoveEdge(r.Context(), id); err != nil { + httpErr(w, err, 404) + return + } + httpJSON(w, map[string]string{"status": "deleted"}, 200) +} + +func (s *RESTServer) handleGetNode(w http.ResponseWriter, r *http.Request) { + start := time.Now() + id := r.PathValue("id") + node, err := s.eng.Store().GetNode(r.Context(), id) + telemetry.MemoryRetrieveDuration.Record(r.Context(), time.Since(start).Seconds()) + if err != nil { + httpErr(w, err, 404) + return + } + neighbors, _ := s.eng.Store().GetNeighbors(r.Context(), id) + httpJSON(w, map[string]any{"node": node, "neighbors": neighbors}, 200) +} + +func (s *RESTServer) handleSubgraph(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + depth := intQuery(r, "depth", 2) + if depth > maxGraphDepth { + httpErr(w, fmt.Errorf("depth exceeds maximum of %d", maxGraphDepth), 400) + return + } + sg, err := s.eng.Graph().ExtractSubgraph(r.Context(), id, depth) + if err != nil { + httpErr(w, err, 500) + return + } + httpJSONCapped(w, sg, 200) +} + +func (s *RESTServer) handleImpact(w http.ResponseWriter, r *http.Request) { + file := r.PathValue("file") + depth := intQuery(r, "depth", 3) + if depth > maxGraphDepth { + httpErr(w, fmt.Errorf("depth exceeds maximum of %d", maxGraphDepth), 400) + return + } + ids, err := s.eng.Graph().Impact(r.Context(), file, depth) + if err != nil { + httpErr(w, err, 500) + return + } + var nodes []*storage.Node + for _, id := range ids { + if n, err := s.eng.Store().GetNode(r.Context(), id); err == nil { + nodes = append(nodes, n) + } + } + httpJSONCapped(w, nodes, 200) +} + +func (s *RESTServer) handleForget(w http.ResponseWriter, r *http.Request) { + if !s.rateLimit(w) { + return + } + id := r.PathValue("id") + if err := s.eng.Forget(r.Context(), id); err != nil { + httpErr(w, err, 404) + return + } + httpJSON(w, map[string]string{"status": "forgotten"}, 200) +} + +func (s *RESTServer) handleUpdateNode(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + node, err := s.eng.Store().GetNode(r.Context(), id) + if err != nil { + httpErr(w, err, 404) + return + } + + var patch struct { + Content *string `json:"content"` + Summary *string `json:"summary"` + Tags *string `json:"tags"` + Key *string `json:"key"` + Pinned *bool `json:"pinned"` + Type *string `json:"type"` + Tier *int `json:"tier"` + } + if err := decodeJSON(r, &patch); err != nil { + httpErr(w, err, 400) + return + } + + // Save version before modifying + if patch.Content != nil && *patch.Content != node.Content { + _ = s.eng.Store().SaveVersion(r.Context(), node.ID, node.Content, "api", "updated via PATCH") + } + + if patch.Content != nil { + node.Content = *patch.Content + node.ContentHash = engine.ContentHash(node.Content, node.Scope, node.Project, node.Type) + } + if patch.Summary != nil { + node.Summary = *patch.Summary + } + if patch.Tags != nil { + node.Tags = *patch.Tags + } + if patch.Key != nil { + node.Key = *patch.Key + } + if patch.Pinned != nil { + node.Pinned = *patch.Pinned + } + if patch.Type != nil { + if !engine.IsValidNodeType(*patch.Type) { + httpErr(w, fmt.Errorf("invalid node type: %q", *patch.Type), 400) + return + } + node.Type = *patch.Type + } + if patch.Tier != nil { + node.Tier = *patch.Tier + } + node.Version++ + if err := s.eng.Store().UpdateNode(r.Context(), node); err != nil { + httpErr(w, err, 500) + return + } + httpJSON(w, node, 200) +} + +func (s *RESTServer) handlePinNode(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + node, err := s.eng.Store().GetNode(r.Context(), id) + if err != nil { + httpErr(w, err, 404) + return + } + node.Pinned = !node.Pinned + if err := s.eng.Store().UpdateNode(r.Context(), node); err != nil { + httpErr(w, err, 500) + return + } + httpJSON(w, map[string]any{"id": node.ID, "pinned": node.Pinned}, 200) +} + +func (s *RESTServer) handleHealth(w http.ResponseWriter, r *http.Request) { + // Actually verify database connectivity with a lightweight query + _, err := s.eng.Store().ListNodes(r.Context(), storage.NodeFilter{}) + if err != nil { + httpJSON(w, map[string]string{"status": "error", "error": err.Error()}, 503) + return + } + httpJSON(w, map[string]string{"status": "ok", "version": version.String()}, 200) +} + +func (s *RESTServer) handleVersion(w http.ResponseWriter, _ *http.Request) { + httpJSON(w, map[string]string{"version": version.String()}, 200) +} + +// handleLiveness responds to Kubernetes liveness probes (/healthz). +// Always returns 200 — if the process is alive, it is live. +func (s *RESTServer) handleLiveness(w http.ResponseWriter, _ *http.Request) { + httpJSON(w, map[string]string{"status": "ok"}, 200) +} + +// handleReadiness responds to Kubernetes readiness probes (/readyz). +// Delegates to the full health check to confirm the database is reachable. +func (s *RESTServer) handleReadiness(w http.ResponseWriter, r *http.Request) { + s.handleHealth(w, r) +} + +func (s *RESTServer) handleStats(w http.ResponseWriter, r *http.Request) { + project := r.URL.Query().Get("project") + st, err := s.eng.Status(r.Context(), project) + if err != nil { + httpErr(w, err, 500) + return + } + httpJSON(w, st, 200) +} + +func (s *RESTServer) handleSessions(w http.ResponseWriter, r *http.Request) { + project := r.URL.Query().Get("project") + limit := intQuery(r, "limit", 10) + sessions, err := s.eng.Store().ListSessions(r.Context(), project, limit) + if err != nil { + httpErr(w, err, 500) + return + } + httpJSON(w, sessions, 200) +} + +func (s *RESTServer) handleSessionStart(w http.ResponseWriter, r *http.Request) { + var body struct { + Project string `json:"project"` + Agent string `json:"agent"` + } + if err := decodeJSON(r, &body); err != nil { + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + httpErr(w, err, 413) + } else { + httpErr(w, err, 400) + } + return + } + sess := &storage.Session{ + ID: uuid.New().String(), + Project: body.Project, + Agent: body.Agent, + StartedAt: time.Now(), + } + if err := s.eng.Store().CreateSession(r.Context(), sess); err != nil { + httpErr(w, err, 500) + return + } + ctxRes, err := s.eng.Context(r.Context(), body.Project) + if err != nil { + slog.Warn("session start: context load failed", "error", err) + } + httpJSON(w, map[string]any{"session": sess, "context": ctxRes}, 201) +} + +func (s *RESTServer) handleSessionEnd(w http.ResponseWriter, r *http.Request) { + var body struct { + ID string `json:"id"` + Summary string `json:"summary"` + } + if err := decodeJSON(r, &body); err != nil { + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + httpErr(w, err, 413) + } else { + httpErr(w, err, 400) + } + return + } + if err := s.eng.Store().EndSession(r.Context(), body.ID, body.Summary); err != nil { + httpErr(w, err, 500) + return + } + httpJSON(w, map[string]string{"status": "ended"}, 200) +} + +func (s *RESTServer) handleStale(w http.ResponseWriter, r *http.Request) { + if s.projectDir == "" { + httpJSON(w, map[string]string{"status": "no project directory configured"}, 200) + return + } + // gitwatch.New is lightweight (path validation + struct creation only). + watcher, err := gitwatch.New(s.eng.Store(), s.eng.Graph(), s.projectDir) + if err != nil { + httpErr(w, err, 500) + return + } + since := time.Now().Add(-7 * 24 * time.Hour) // last 7 days + reports, err := watcher.StalesSince(r.Context(), since) + if err != nil { + httpErr(w, err, 500) + return + } + httpJSON(w, reports, 200) +} + +func (s *RESTServer) handleEmbed(w http.ResponseWriter, r *http.Request) { + var body struct { + NodeID string `json:"node_id"` + } + if err := decodeJSON(r, &body); err != nil { + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + httpErr(w, err, 413) + } else { + httpErr(w, err, 400) + } + return + } + if s.embedder == nil { + httpErr(w, fmt.Errorf("no embedding provider configured"), 503) + return + } + node, err := s.eng.Store().GetNode(r.Context(), body.NodeID) + if err != nil { + httpErr(w, err, 404) + return + } + // Document mode: stored content is embedded as a document (search_document) + // so it pairs correctly with query-mode embeddings at search time. + vec, err := s.embedder.EmbedWithMode(r.Context(), node.Content, embeddings.ModeDocument) + if err != nil { + httpErr(w, err, 500) + return + } + if err := s.eng.Store().SaveEmbedding(r.Context(), node.ID, s.embedder.Name(), vec); err != nil { + httpErr(w, err, 500) + return + } + httpJSON(w, map[string]any{"node_id": node.ID, "dims": len(vec)}, 200) +} + +func (s *RESTServer) handleHybridRecall(w http.ResponseWriter, r *http.Request) { + var opts engine.RecallOpts + if err := decodeJSON(r, &opts); err != nil { + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + httpErr(w, err, 413) + } else { + httpErr(w, err, 400) + } + return + } + hs := engine.NewHybridSearch(s.eng.Store(), s.eng.Graph(), s.embedder) + scored, err := hs.Search(r.Context(), opts.Query, opts) + if err != nil { + httpErr(w, err, 500) + return + } + reranked := engine.Rerank(r.Context(), scored, s.eng.Store()) + httpJSONCapped(w, reranked, 200) +} + +func (s *RESTServer) handleProactive(w http.ResponseWriter, r *http.Request) { + project := r.URL.Query().Get("project") + hs := engine.NewHybridSearch(s.eng.Store(), s.eng.Graph(), s.embedder) + pc := engine.NewProactiveContext(s.eng, hs) + nodes, err := pc.Predict(r.Context(), project, 2000) + if err != nil { + httpErr(w, err, 500) + return + } + httpJSON(w, map[string]any{ + "nodes": nodes, + "context": engine.FormatContext(nodes), + }, 200) +} + +func (s *RESTServer) handleFeedback(w http.ResponseWriter, r *http.Request) { + var body struct { + ID string `json:"id"` + Action engine.FeedbackAction `json:"action"` + NewContent string `json:"new_content"` + } + if err := decodeJSON(r, &body); err != nil { + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + httpErr(w, err, 413) + } else { + httpErr(w, err, 400) + } + return + } + if err := s.eng.Feedback(r.Context(), body.ID, body.Action, body.NewContent); err != nil { + httpErr(w, err, 400) + return + } + httpJSON(w, map[string]string{"status": "ok"}, 200) +} + +func (s *RESTServer) handleDecay(w http.ResponseWriter, r *http.Request) { + if !s.rateLimit(w) { + return + } + if err := engine.RunDecay(r.Context(), s.eng.Store(), s.eng.DecayConfig); err != nil { + httpErr(w, err, 500) + return + } + httpJSON(w, map[string]string{"status": "decay applied"}, 200) +} + +func (s *RESTServer) handleGC(w http.ResponseWriter, r *http.Request) { + if !s.rateLimit(w) { + return + } + n, err := engine.GarbageCollect(r.Context(), s.eng.Store(), s.eng.DecayConfig) + if err != nil { + httpErr(w, err, 500) + return + } + httpJSON(w, map[string]int{"removed": n}, 200) +} + +func (s *RESTServer) handleReplay(w http.ResponseWriter, r *http.Request) { + sessionID := r.PathValue("session_id") + events, err := s.eng.Store().GetReplayEvents(r.Context(), sessionID) + if err != nil { + httpErr(w, err, 500) + return + } + httpJSON(w, events, 200) +} diff --git a/storage/sqlite.go b/storage/sqlite.go index 94cd6bb..e967fbe 100644 --- a/storage/sqlite.go +++ b/storage/sqlite.go @@ -12,14 +12,15 @@ import ( "sync" "time" - "github.com/GrayCodeAI/yaad/internal/telemetry" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/metric" - sqlite3 "modernc.org/sqlite" sqlite3lib "modernc.org/sqlite/lib" ) +// Note: node, edge, and transaction storage operations were moved +// verbatim into sqlite_nodes.go, sqlite_edges.go, and sqlite_tx.go for +// readability. This file keeps shared infra (schema, busy-retry, stmt +// cache, sessions, file-watch, and scan helpers). + // isSQLITEBusy reports whether err wraps an SQLITE_BUSY (code 5) error. // // It type-asserts the driver's *sqlite.Error rather than matching on the error @@ -507,540 +508,6 @@ CREATE TRIGGER IF NOT EXISTS nodes_ad AFTER DELETE ON nodes BEGIN END; ` -// --- Nodes --- - -func (s *Store) CreateNode(ctx context.Context, n *Node) error { - start := time.Now() - ctx, cancel := s.withTimeout(ctx) - defer cancel() - err := retryOnBusy(func() error { - return createNodeQ(ctx, s.q(), n) - }, 5, 2*time.Millisecond) - attrs := attribute.NewSet(attribute.String("op", "create_node")) - telemetry.SQLiteQueryDuration.Record(ctx, time.Since(start).Seconds(), metric.WithAttributeSet(attrs)) - telemetry.SQLiteQueryCount.Add(ctx, 1, metric.WithAttributeSet(attrs)) - return err -} - -func createNodeQ(ctx context.Context, q queryable, n *Node) error { - _, err := q.ExecContext(ctx, `INSERT INTO nodes (id, type, content, content_hash, summary, scope, project, tier, tags, key, pinned, confidence, access_count, created_at, updated_at, accessed_at, source_session, source_agent, version) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, - n.ID, n.Type, n.Content, n.ContentHash, n.Summary, n.Scope, n.Project, n.Tier, n.Tags, nullString(n.Key), n.Pinned, n.Confidence, n.AccessCount, - n.CreatedAt, n.UpdatedAt, nullTime(n.AccessedAt), n.SourceSession, n.SourceAgent, n.Version) - if err != nil && strings.Contains(err.Error(), "UNIQUE constraint failed") { - return fmt.Errorf("%w: %s", ErrDuplicateNode, err) - } - if err != nil { - return err - } - // Persist structured metadata. - return saveNodeMetadataQ(ctx, q, n.ID, n.Metadata) -} - -// saveNodeMetadataQ replaces all metadata for a node, inserting rows for each -// key-value pair. Uses the same queryable (pooled connection or transaction). -func saveNodeMetadataQ(ctx context.Context, q queryable, nodeID string, meta map[string]string) error { - if len(meta) == 0 { - return nil - } - // Delete existing metadata for this node (upsert semantics). - if _, err := q.ExecContext(ctx, `DELETE FROM node_metadata WHERE node_id = ?`, nodeID); err != nil { - return err - } - for k, v := range meta { - if _, err := q.ExecContext(ctx, - `INSERT INTO node_metadata (node_id, key, value) VALUES (?, ?, ?)`, nodeID, k, v); err != nil { - return err - } - } - return nil -} - -// GetNode retrieves a node by its primary key ID. -// Returns ErrNodeNotFound wrapped with the ID when no row is found. -func (s *Store) GetNode(ctx context.Context, id string) (*Node, error) { - start := time.Now() - n, err := retryOnBusyVal(func() (*Node, error) { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - return getNodeQ(ctx, s.q(), id) - }, 5, 50*time.Millisecond) - attrs := attribute.NewSet(attribute.String("op", "get_node")) - telemetry.SQLiteQueryDuration.Record(ctx, time.Since(start).Seconds(), metric.WithAttributeSet(attrs)) - telemetry.SQLiteQueryCount.Add(ctx, 1, metric.WithAttributeSet(attrs)) - _ = err // metrics always recorded - return n, err -} - -func getNodeQ(ctx context.Context, q queryable, id string) (*Node, error) { - n := &Node{} - var accessedAt sql.NullTime - var key sql.NullString - err := q.QueryRowContext(ctx, `SELECT id, type, content, content_hash, summary, scope, project, tier, tags, key, pinned, confidence, access_count, created_at, updated_at, accessed_at, source_session, source_agent, version FROM nodes WHERE id = ?`, id). - Scan(&n.ID, &n.Type, &n.Content, &n.ContentHash, &n.Summary, &n.Scope, &n.Project, &n.Tier, &n.Tags, &key, &n.Pinned, &n.Confidence, &n.AccessCount, &n.CreatedAt, &n.UpdatedAt, &accessedAt, &n.SourceSession, &n.SourceAgent, &n.Version) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, fmt.Errorf("%w: %s", ErrNodeNotFound, id) - } - return nil, err - } - if accessedAt.Valid { - n.AccessedAt = accessedAt.Time - } - if key.Valid { - n.Key = key.String - } - return n, nil -} - -// GetNodeByKey looks up a node by its unique key within a project. -// Returns (nil, nil) when no matching node is found (upsert check pattern). -func (s *Store) GetNodeByKey(ctx context.Context, key, project string) (*Node, error) { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - return getNodeByKeyQ(ctx, s.q(), key, project) -} - -func getNodeByKeyQ(ctx context.Context, q queryable, key, project string) (*Node, error) { - n := &Node{} - var accessedAt sql.NullTime - var k sql.NullString - err := q.QueryRowContext(ctx, `SELECT id, type, content, content_hash, summary, scope, project, tier, tags, key, pinned, confidence, access_count, created_at, updated_at, accessed_at, source_session, source_agent, version FROM nodes WHERE key = ? AND project = ?`, key, project). - Scan(&n.ID, &n.Type, &n.Content, &n.ContentHash, &n.Summary, &n.Scope, &n.Project, &n.Tier, &n.Tags, &k, &n.Pinned, &n.Confidence, &n.AccessCount, &n.CreatedAt, &n.UpdatedAt, &accessedAt, &n.SourceSession, &n.SourceAgent, &n.Version) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, nil - } - return nil, err - } - if accessedAt.Valid { - n.AccessedAt = accessedAt.Time - } - if k.Valid { - n.Key = k.String - } - return n, nil -} - -func (s *Store) UpdateNode(ctx context.Context, n *Node) error { - start := time.Now() - err := retryOnBusy(func() error { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - return updateNodeQ(ctx, s.q(), n) - }, 5, 50*time.Millisecond) - attrs := attribute.NewSet(attribute.String("op", "update_node")) - telemetry.SQLiteQueryDuration.Record(ctx, time.Since(start).Seconds(), metric.WithAttributeSet(attrs)) - telemetry.SQLiteQueryCount.Add(ctx, 1, metric.WithAttributeSet(attrs)) - return err -} - -func updateNodeQ(ctx context.Context, q queryable, n *Node) error { - _, err := q.ExecContext(ctx, `UPDATE nodes SET type=?, content=?, content_hash=?, summary=?, scope=?, project=?, tier=?, tags=?, key=?, pinned=?, confidence=?, access_count=?, updated_at=?, accessed_at=?, source_session=?, source_agent=?, version=? WHERE id=?`, - n.Type, n.Content, n.ContentHash, n.Summary, n.Scope, n.Project, n.Tier, n.Tags, nullString(n.Key), n.Pinned, n.Confidence, n.AccessCount, - n.UpdatedAt, nullTime(n.AccessedAt), n.SourceSession, n.SourceAgent, n.Version, n.ID) - if err != nil { - return err - } - // Persist structured metadata (delete + insert). - return saveNodeMetadataQ(ctx, q, n.ID, n.Metadata) -} - -func (s *Store) UpdateNodeContent(ctx context.Context, id, newContent string) error { - return retryOnBusy(func() error { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - return updateNodeContentQ(ctx, s.q(), id, newContent) - }, 5, 50*time.Millisecond) -} - -func updateNodeContentQ(ctx context.Context, q queryable, id, newContent string) error { - _, err := q.ExecContext(ctx, `UPDATE nodes SET content=?, updated_at=CURRENT_TIMESTAMP WHERE id=?`, newContent, id) - return err -} - -func (s *Store) DeleteNode(ctx context.Context, id string) error { - return retryOnBusy(func() error { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return err - } - defer func() { _ = tx.Rollback() }() - if err := deleteNodeQ(ctx, tx, id); err != nil { - return err - } - return tx.Commit() - }, 5, 50*time.Millisecond) -} - -func deleteNodeQ(ctx context.Context, q queryable, id string) error { - if _, err := q.ExecContext(ctx, `DELETE FROM edges WHERE from_id=? OR to_id=?`, id, id); err != nil { - return err - } - _, err := q.ExecContext(ctx, `DELETE FROM nodes WHERE id=?`, id) - return err -} - -func (s *Store) ListNodes(ctx context.Context, f NodeFilter) ([]*Node, error) { - start := time.Now() - nodes, err := retryOnBusyVal(func() ([]*Node, error) { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - return listNodesQ(ctx, s.q(), f) - }, 5, 50*time.Millisecond) - attrs := attribute.NewSet(attribute.String("op", "list_nodes")) - telemetry.SQLiteQueryDuration.Record(ctx, time.Since(start).Seconds(), metric.WithAttributeSet(attrs)) - telemetry.SQLiteQueryCount.Add(ctx, 1, metric.WithAttributeSet(attrs)) - return nodes, err -} - -func listNodesQ(ctx context.Context, q queryable, f NodeFilter) ([]*Node, error) { - query := "SELECT id, type, content, content_hash, summary, scope, project, tier, tags, key, pinned, confidence, access_count, created_at, updated_at, accessed_at, source_session, source_agent, version FROM nodes WHERE 1=1" - var args []any - if f.Type != "" { - query += " AND type=?" - args = append(args, f.Type) - } - if f.Scope != "" { - query += " AND scope=?" - args = append(args, f.Scope) - } - if f.Project != "" { - query += " AND project=?" - args = append(args, f.Project) - } - if f.Tier > 0 { - query += " AND tier=?" - args = append(args, f.Tier) - } - if f.MinConfidence > 0 { - query += " AND confidence>=?" - args = append(args, f.MinConfidence) - } - if f.SourceSession != "" { - query += " AND source_session=?" - args = append(args, f.SourceSession) - } - if f.Pinned != nil { - query += " AND pinned=?" - args = append(args, *f.Pinned) - } - if f.Tag != "" { - // Delimiter-aware exact tag match: wrap both the stored CSV and the - // target in commas so "topic:foo" does not match "topic:foobar". - // '%' and '_' in the tag are escaped so they are treated literally. - esc := strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`).Replace(f.Tag) - query += ` AND (',' || tags || ',') LIKE ? ESCAPE '\'` - args = append(args, "%,"+esc+",%") - } - // Metadata key-value filters (AND semantics — all must match). - for k, v := range f.MetadataFilters { - query += ` AND EXISTS (SELECT 1 FROM node_metadata nm WHERE nm.node_id = nodes.id AND nm.key = ? AND nm.value = ?)` - args = append(args, k, v) - } - query += " LIMIT ?" - if f.Limit > 0 { - args = append(args, f.Limit) - } else { - args = append(args, 1000) - } - if f.Offset > 0 { - query += " OFFSET ?" - args = append(args, f.Offset) - } - rows, err := q.QueryContext(ctx, query, args...) - if err != nil { - return nil, err - } - defer func() { _ = rows.Close() }() - return scanNodes(rows) -} - -// escapeFTS5 escapes special FTS5 characters by wrapping each token in double -// quotes and escaping embedded quotes. This prevents FTS5 query injection via -// operators like *, -, AND, OR, NOT. -func escapeFTS5(query string) string { - words := strings.Fields(query) - for i, w := range words { - // Escape embedded quotes by doubling them, then wrap in quotes - w = strings.ReplaceAll(w, `"`, `""`) - words[i] = `"` + w + `"` - } - return strings.Join(words, " OR ") -} - -func (s *Store) SearchNodes(ctx context.Context, query string, limit int) ([]*Node, error) { - start := time.Now() - nodes, err := retryOnBusyVal(func() ([]*Node, error) { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - return searchNodesQ(ctx, s.q(), query, limit) - }, 5, 50*time.Millisecond) - attrs := attribute.NewSet(attribute.String("op", "search_nodes")) - telemetry.SQLiteQueryDuration.Record(ctx, time.Since(start).Seconds(), metric.WithAttributeSet(attrs)) - telemetry.SQLiteQueryCount.Add(ctx, 1, metric.WithAttributeSet(attrs)) - return nodes, err -} - -func searchNodesQ(ctx context.Context, q queryable, query string, limit int) ([]*Node, error) { - if limit <= 0 { - limit = 10 - } - ftsQuery := escapeFTS5(query) - rows, err := q.QueryContext(ctx, `SELECT n.id, n.type, n.content, n.content_hash, n.summary, n.scope, n.project, n.tier, n.tags, n.key, n.pinned, n.confidence, n.access_count, n.created_at, n.updated_at, n.accessed_at, n.source_session, n.source_agent, n.version - FROM nodes_fts f JOIN nodes n ON f.rowid = n.rowid WHERE nodes_fts MATCH ? ORDER BY rank LIMIT ?`, ftsQuery, limit) - if err != nil { - return nil, err - } - defer func() { _ = rows.Close() }() - return scanNodes(rows) -} - -// --- Edges --- - -func (s *Store) CreateEdge(ctx context.Context, e *Edge) error { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - // Retry on SQLITE_BUSY: SelfLink runs outside the engine's write lock and - // competes for the SQLite write lock with concurrent graph operations. - return retryOnBusy(func() error { - return createEdgeQ(ctx, s.q(), e) - }, 5, 2*time.Millisecond) -} - -func createEdgeQ(ctx context.Context, q queryable, e *Edge) error { - // Temporal validity: an edge becomes valid the moment it is created unless - // the caller explicitly supplied a valid_at. This keeps valid_at "live" for - // every insert path without requiring callers to set it. CreatedAt is also - // defaulted to now so the columns are never NULL for new rows. - now := time.Now().UTC() - if e.ValidAt.IsZero() { - e.ValidAt = now - } - if e.CreatedAt.IsZero() { - e.CreatedAt = now - } - _, err := q.ExecContext(ctx, `INSERT INTO edges (id, from_id, to_id, type, acyclic, weight, metadata, valid_at, invalid_at, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, - e.ID, e.FromID, e.ToID, e.Type, e.Acyclic, e.Weight, e.Metadata, nullTime(e.ValidAt), nullTime(e.InvalidAt), e.CreatedAt) - if err != nil && strings.Contains(err.Error(), "UNIQUE constraint failed") { - return fmt.Errorf("%w: %s", ErrDuplicateEdge, err) - } - return err -} - -// GetEdge retrieves an edge by its primary key ID. -// Returns ErrEdgeNotFound wrapped with the ID when no row is found. -func (s *Store) GetEdge(ctx context.Context, id string) (*Edge, error) { - return retryOnBusyVal(func() (*Edge, error) { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - return getEdgeQ(ctx, s.q(), id) - }, 5, 50*time.Millisecond) -} - -func getEdgeQ(ctx context.Context, q queryable, id string) (*Edge, error) { - e := &Edge{} - var validAt, invalidAt sql.NullTime - err := q.QueryRowContext(ctx, `SELECT id, from_id, to_id, type, acyclic, weight, metadata, valid_at, invalid_at, created_at FROM edges WHERE id=?`, id). - Scan(&e.ID, &e.FromID, &e.ToID, &e.Type, &e.Acyclic, &e.Weight, &e.Metadata, &validAt, &invalidAt, &e.CreatedAt) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, fmt.Errorf("%w: %s", ErrEdgeNotFound, id) - } - return nil, err - } - if validAt.Valid { - e.ValidAt = validAt.Time - } - if invalidAt.Valid { - e.InvalidAt = invalidAt.Time - } - return e, nil -} - -func (s *Store) InvalidateEdge(ctx context.Context, id string) error { - return retryOnBusy(func() error { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - return invalidateEdgeQ(ctx, s.q(), id) - }, 5, 50*time.Millisecond) -} - -func invalidateEdgeQ(ctx context.Context, q queryable, id string) error { - _, err := q.ExecContext(ctx, `UPDATE edges SET invalid_at = ? WHERE id = ?`, time.Now().UTC(), id) - return err -} - -func (s *Store) DeleteEdge(ctx context.Context, id string) error { - return retryOnBusy(func() error { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - return deleteEdgeQ(ctx, s.q(), id) - }, 5, 50*time.Millisecond) -} - -func deleteEdgeQ(ctx context.Context, q queryable, id string) error { - _, err := q.ExecContext(ctx, `DELETE FROM edges WHERE id=?`, id) - return err -} - -func (s *Store) GetEdgesFrom(ctx context.Context, nodeID string) ([]*Edge, error) { - return retryOnBusyVal(func() ([]*Edge, error) { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - return queryEdgesQ(ctx, s.q(), `SELECT id, from_id, to_id, type, acyclic, weight, metadata, valid_at, invalid_at, created_at FROM edges WHERE from_id=? AND invalid_at IS NULL`, nodeID) - }, 5, 50*time.Millisecond) -} - -func (s *Store) GetEdgesTo(ctx context.Context, nodeID string) ([]*Edge, error) { - return retryOnBusyVal(func() ([]*Edge, error) { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - return queryEdgesQ(ctx, s.q(), `SELECT id, from_id, to_id, type, acyclic, weight, metadata, valid_at, invalid_at, created_at FROM edges WHERE to_id=? AND invalid_at IS NULL`, nodeID) - }, 5, 50*time.Millisecond) -} - -// GetValidEdgesFrom returns outbound edges from nodeID that were valid at the -// given instant: valid_at <= at AND (invalid_at IS NULL OR invalid_at > at). -// A zero `at` defaults to now, giving a "currently valid" view. This is a -// point-in-time variant of GetEdgesFrom; the latter is left unchanged so -// existing callers keep their current behavior. -func (s *Store) GetValidEdgesFrom(ctx context.Context, nodeID string, at time.Time) ([]*Edge, error) { - at = validityInstant(at) - return retryOnBusyVal(func() ([]*Edge, error) { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - return queryEdgesQ(ctx, s.q(), `SELECT id, from_id, to_id, type, acyclic, weight, metadata, valid_at, invalid_at, created_at FROM edges - WHERE from_id=? AND (valid_at IS NULL OR valid_at <= ?) AND (invalid_at IS NULL OR invalid_at > ?)`, nodeID, at, at) - }, 5, 50*time.Millisecond) -} - -// GetValidEdgesTo is the inbound counterpart of GetValidEdgesFrom. -func (s *Store) GetValidEdgesTo(ctx context.Context, nodeID string, at time.Time) ([]*Edge, error) { - at = validityInstant(at) - return retryOnBusyVal(func() ([]*Edge, error) { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - return queryEdgesQ(ctx, s.q(), `SELECT id, from_id, to_id, type, acyclic, weight, metadata, valid_at, invalid_at, created_at FROM edges - WHERE to_id=? AND (valid_at IS NULL OR valid_at <= ?) AND (invalid_at IS NULL OR invalid_at > ?)`, nodeID, at, at) - }, 5, 50*time.Millisecond) -} - -// validityInstant normalizes a point-in-time argument: a zero time means "now". -func validityInstant(at time.Time) time.Time { - if at.IsZero() { - return time.Now().UTC() - } - return at.UTC() -} - -func queryEdgesQ(ctx context.Context, q queryable, query string, args ...any) ([]*Edge, error) { - rows, err := q.QueryContext(ctx, query, args...) - if err != nil { - return nil, err - } - defer func() { _ = rows.Close() }() - return scanEdges(rows) -} - -func (s *Store) GetEdgesBetween(ctx context.Context, nodeIDs []string) ([]*Edge, error) { - return retryOnBusyVal(func() ([]*Edge, error) { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - return getEdgesBetweenQ(ctx, s.q(), nodeIDs) - }, 5, 50*time.Millisecond) -} - -func getEdgesBetweenQ(ctx context.Context, q queryable, nodeIDs []string) ([]*Edge, error) { - if len(nodeIDs) == 0 { - return nil, nil - } - var all []*Edge - for i := 0; i < len(nodeIDs); i += maxSQLVariables { - end := i + maxSQLVariables - if end > len(nodeIDs) { - end = len(nodeIDs) - } - chunk := nodeIDs[i:end] - placeholders := make([]string, len(chunk)) - args := make([]any, 0, len(chunk)*2) - for j, id := range chunk { - placeholders[j] = "?" - args = append(args, id) - } - ph := strings.Join(placeholders, ",") - query := fmt.Sprintf(`SELECT id, from_id, to_id, type, acyclic, weight, metadata, valid_at, invalid_at, created_at FROM edges - WHERE from_id IN (%s) AND to_id IN (%s)`, ph, ph) - args = append(args, args...) - edges, err := queryEdgesQ(ctx, q, query, args...) - if err != nil { - return nil, err - } - all = append(all, edges...) - } - return all, nil -} - -// GetAllEdgesFor returns all edges where from_id OR to_id is in the given set. -// Used by IntentBFS for batch edge retrieval (avoids N+1 queries). -func (s *Store) GetAllEdgesFor(ctx context.Context, nodeIDs []string) ([]*Edge, error) { - return retryOnBusyVal(func() ([]*Edge, error) { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - return getAllEdgesForQ(ctx, s.q(), nodeIDs) - }, 5, 50*time.Millisecond) -} - -func getAllEdgesForQ(ctx context.Context, q queryable, nodeIDs []string) ([]*Edge, error) { - if len(nodeIDs) == 0 { - return nil, nil - } - var all []*Edge - for i := 0; i < len(nodeIDs); i += maxSQLVariables { - end := i + maxSQLVariables - if end > len(nodeIDs) { - end = len(nodeIDs) - } - chunk := nodeIDs[i:end] - placeholders := make([]string, len(chunk)) - args := make([]any, 0, len(chunk)*2) - for j, id := range chunk { - placeholders[j] = "?" - args = append(args, id) - } - ph := strings.Join(placeholders, ",") - query := fmt.Sprintf(`SELECT id, from_id, to_id, type, acyclic, weight, metadata, valid_at, invalid_at, created_at FROM edges - WHERE from_id IN (%s) OR to_id IN (%s)`, ph, ph) - args = append(args, args[:len(chunk)]...) // second set of args for to_id - edges, err := queryEdgesQ(ctx, q, query, args...) - if err != nil { - return nil, err - } - all = append(all, edges...) - } - return all, nil -} - -func (s *Store) GetNeighbors(ctx context.Context, nodeID string) ([]*Node, error) { - return retryOnBusyVal(func() ([]*Node, error) { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - return getNeighborsQ(ctx, s.q(), nodeID) - }, 5, 50*time.Millisecond) -} - -func getNeighborsQ(ctx context.Context, q queryable, nodeID string) ([]*Node, error) { - rows, err := q.QueryContext(ctx, `SELECT DISTINCT n.id, n.type, n.content, n.content_hash, n.summary, n.scope, n.project, n.tier, n.tags, n.key, n.pinned, n.confidence, n.access_count, n.created_at, n.updated_at, n.accessed_at, n.source_session, n.source_agent, n.version - FROM nodes n JOIN edges e ON (e.to_id = n.id AND e.from_id = ?) OR (e.from_id = n.id AND e.to_id = ?)`, nodeID, nodeID) - if err != nil { - return nil, err - } - defer func() { _ = rows.Close() }() - return scanNodes(rows) -} - // --- Sessions --- func (s *Store) CreateSession(ctx context.Context, sess *Session) error { @@ -1139,67 +606,13 @@ func (s *Store) GetNodesByFile(ctx context.Context, filePath string) ([]*Node, e }, 5, 50*time.Millisecond) } -// --- Versions --- +// --- Helpers --- -func (s *Store) SaveVersion(ctx context.Context, nodeID string, content, changedBy, reason string) error { - return retryOnBusy(func() error { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return err - } - defer func() { _ = tx.Rollback() }() - if err := saveVersionQ(ctx, tx, nodeID, content, changedBy, reason); err != nil { - return err - } - return tx.Commit() - }, 5, 50*time.Millisecond) -} - -func saveVersionQ(ctx context.Context, q queryable, nodeID string, content, changedBy, reason string) error { - var maxVer int - err := q.QueryRowContext(ctx, `SELECT COALESCE(MAX(version), 0) FROM node_versions WHERE node_id=?`, nodeID).Scan(&maxVer) - if err != nil { - return err - } - _, err = q.ExecContext(ctx, `INSERT INTO node_versions (node_id, version, content, changed_at, changed_by, reason) VALUES (?, ?, ?, ?, ?, ?)`, - nodeID, maxVer+1, content, time.Now().UTC(), changedBy, reason) - return err -} - -func (s *Store) GetVersions(ctx context.Context, nodeID string) ([]*NodeVersion, error) { - return retryOnBusyVal(func() ([]*NodeVersion, error) { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - return getVersionsQ(ctx, s.q(), nodeID) - }, 5, 50*time.Millisecond) -} - -func getVersionsQ(ctx context.Context, q queryable, nodeID string) ([]*NodeVersion, error) { - rows, err := q.QueryContext(ctx, `SELECT node_id, version, content, changed_at, changed_by, reason FROM node_versions WHERE node_id=? ORDER BY version`, nodeID) - if err != nil { - return nil, err - } - defer func() { _ = rows.Close() }() - var out []*NodeVersion - for rows.Next() { - v := &NodeVersion{} - if err := rows.Scan(&v.NodeID, &v.Version, &v.Content, &v.ChangedAt, &v.ChangedBy, &v.Reason); err != nil { - return nil, err - } - out = append(out, v) - } - return out, rows.Err() -} - -// --- Helpers --- - -func nullTime(t time.Time) sql.NullTime { - if t.IsZero() { - return sql.NullTime{} - } - return sql.NullTime{Time: t, Valid: true} +func nullTime(t time.Time) sql.NullTime { + if t.IsZero() { + return sql.NullTime{} + } + return sql.NullTime{Time: t, Valid: true} } func nullString(s string) sql.NullString { @@ -1251,765 +664,3 @@ func scanEdges(rows *sql.Rows) ([]*Edge, error) { // maxSQLVariables is the maximum number of SQLite host parameters per query. // SQLite default is 999; we stay well under it for safety. const maxSQLVariables = 900 - -// --- AccessLog --- - -// LogAccess records a lightweight access event (INSERT only, no UPDATE churn). -func (s *Store) LogAccess(ctx context.Context, nodeID string) error { - return retryOnBusy(func() error { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - return logAccessQ(ctx, s.q(), nodeID) - }, 5, 50*time.Millisecond) -} - -func logAccessQ(ctx context.Context, q queryable, nodeID string) error { - _, err := q.ExecContext(ctx, `INSERT INTO access_log (node_id) VALUES (?)`, nodeID) - return err -} - -// FlushAccessLog aggregates access_log entries into nodes.access_count / accessed_at, -// then truncates the log. Runs atomically within a transaction. -func (s *Store) FlushAccessLog(ctx context.Context) (int, error) { - return retryOnBusyVal(func() (int, error) { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return 0, err - } - defer func() { _ = tx.Rollback() }() - n, err := flushAccessLogQ(ctx, tx) - if err != nil { - return 0, err - } - return n, tx.Commit() - }, 5, 50*time.Millisecond) -} - -func flushAccessLogQ(ctx context.Context, q queryable) (int, error) { - rows, err := q.QueryContext(ctx, ` - SELECT node_id, COUNT(*) as cnt, MAX(created_at) as last_at - FROM access_log - GROUP BY node_id`) - if err != nil { - return 0, err - } - defer func() { _ = rows.Close() }() - type agg struct { - nodeID string - count int - lastAt time.Time - } - var aggs []agg - for rows.Next() { - var a agg - var lastAtStr string - if err := rows.Scan(&a.nodeID, &a.count, &lastAtStr); err != nil { - return 0, err - } - if lastAtStr != "" { - t, _ := time.Parse(time.RFC3339Nano, lastAtStr) - if t.IsZero() { - t, _ = time.Parse("2006-01-02 15:04:05", lastAtStr) - } - a.lastAt = t - } - aggs = append(aggs, a) - } - if err := rows.Err(); err != nil { - return 0, err - } - if len(aggs) == 0 { - return 0, nil - } - for _, a := range aggs { - _, err := q.ExecContext(ctx, ` - UPDATE nodes - SET access_count = access_count + ?, - accessed_at = MAX(COALESCE(accessed_at, '1970-01-01'), ?) - WHERE id = ?`, - a.count, a.lastAt, a.nodeID) - if err != nil { - return 0, err - } - } - _, err = q.ExecContext(ctx, `DELETE FROM access_log`) - if err != nil { - return 0, err - } - return len(aggs), nil -} - -// --- Metadata --- - -func (s *Store) SaveNodeMetadata(ctx context.Context, nodeID string, meta map[string]string) error { - return retryOnBusy(func() error { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - return saveNodeMetadataQ(ctx, s.q(), nodeID, meta) - }, 5, 50*time.Millisecond) -} - -func (s *Store) LoadNodeMetadata(ctx context.Context, nodeIDs []string) (map[string]map[string]string, error) { - if len(nodeIDs) == 0 { - return nil, nil - } - ids, args := chunkedArgs(nodeIDs) - return retryOnBusyVal(func() (map[string]map[string]string, error) { - return loadNodeMetadataQ(ctx, s.q(), ids, args) - }, 2, 10*time.Millisecond) -} - -func loadNodeMetadataQ(ctx context.Context, q queryable, idsChunk []string, args []any) (map[string]map[string]string, error) { - query := "SELECT node_id, key, value FROM node_metadata WHERE node_id IN (?" - for i := 1; i < len(idsChunk); i++ { - query += ", ?" - } - query += ")" - rows, err := q.QueryContext(ctx, query, args...) - if err != nil { - return nil, err - } - defer func() { _ = rows.Close() }() - - result := make(map[string]map[string]string) - for rows.Next() { - var nodeID, key, value string - if err := rows.Scan(&nodeID, &key, &value); err != nil { - return nil, err - } - if result[nodeID] == nil { - result[nodeID] = make(map[string]string) - } - result[nodeID][key] = value - } - return result, rows.Err() -} - -func chunkedArgs(ids []string) ([]string, []any) { - args := make([]any, len(ids)) - for i, id := range ids { - args[i] = id - } - return ids, args -} - -// FillNodeMetadata loads metadata for all given nodes in one batch query and -// populates each node's Metadata field in-place. No-op for an empty slice. -func (s *Store) FillNodeMetadata(ctx context.Context, nodes []*Node) error { - if len(nodes) == 0 { - return nil - } - ids := make([]string, len(nodes)) - for i, n := range nodes { - ids[i] = n.ID - } - meta, err := s.LoadNodeMetadata(ctx, ids) - if err != nil { - return err - } - for _, n := range nodes { - n.Metadata = meta[n.ID] - } - return nil -} - -// --- Signatures --- - -func (s *Store) SaveSignature(ctx context.Context, nodeID, signature string) error { - return retryOnBusy(func() error { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - return saveSignatureQ(ctx, s.q(), nodeID, signature) - }, 5, 50*time.Millisecond) -} - -func saveSignatureQ(ctx context.Context, q queryable, nodeID, signature string) error { - _, err := q.ExecContext(ctx, - `INSERT OR REPLACE INTO node_signatures (node_id, signature, signed_at) VALUES (?, ?, CURRENT_TIMESTAMP)`, - nodeID, signature) - return err -} - -func (s *Store) GetAllSignatures(ctx context.Context) (map[string]string, error) { - return retryOnBusyVal(func() (map[string]string, error) { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - return getAllSignaturesQ(ctx, s.db) - }, 5, 50*time.Millisecond) -} - -func getAllSignaturesQ(ctx context.Context, q queryable) (map[string]string, error) { - rows, err := q.QueryContext(ctx, `SELECT node_id, signature FROM node_signatures`) - if err != nil { - return nil, err - } - defer func() { _ = rows.Close() }() - out := make(map[string]string) - for rows.Next() { - var nodeID, sig string - if err := rows.Scan(&nodeID, &sig); err != nil { - return nil, err - } - out[nodeID] = sig - } - return out, rows.Err() -} - -func (s *Store) GetNodesBatch(ctx context.Context, ids []string) ([]*Node, error) { - return retryOnBusyVal(func() ([]*Node, error) { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - return getNodesBatchQ(ctx, s.q(), ids) - }, 5, 50*time.Millisecond) -} - -func getNodesBatchQ(ctx context.Context, q queryable, ids []string) ([]*Node, error) { - if len(ids) == 0 { - return nil, nil - } - var all []*Node - for i := 0; i < len(ids); i += maxSQLVariables { - end := i + maxSQLVariables - if end > len(ids) { - end = len(ids) - } - chunk := ids[i:end] - placeholders := make([]string, len(chunk)) - args := make([]any, len(chunk)) - for j, id := range chunk { - placeholders[j] = "?" - args[j] = id - } - query := fmt.Sprintf(`SELECT id, type, content, content_hash, summary, scope, project, tier, tags, key, pinned, confidence, access_count, created_at, updated_at, accessed_at, source_session, source_agent, version FROM nodes WHERE id IN (%s)`, strings.Join(placeholders, ",")) - rows, err := q.QueryContext(ctx, query, args...) - if err != nil { - return nil, err - } - nodes, err := scanNodes(rows) - _ = rows.Close() - if err != nil { - return nil, err - } - all = append(all, nodes...) - } - return all, nil -} - -// CountEdges returns inbound and outbound edge counts for a node. -func (s *Store) CountEdges(ctx context.Context, nodeID string) (inbound int, outbound int, err error) { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - return countEdgesQ(ctx, s.q(), nodeID) -} - -func countEdgesQ(ctx context.Context, q queryable, nodeID string) (int, int, error) { - var inbound, outbound int - err := q.QueryRowContext(ctx, - `SELECT (SELECT COUNT(*) FROM edges WHERE to_id = ?), (SELECT COUNT(*) FROM edges WHERE from_id = ?)`, - nodeID, nodeID).Scan(&inbound, &outbound) - return inbound, outbound, err -} - -// CountAllEdges returns the total number of edges in the graph. -func (s *Store) CountAllEdges(ctx context.Context) (int, error) { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - return countAllEdgesQ(ctx, s.db) -} - -func countAllEdgesQ(ctx context.Context, q queryable) (int, error) { - var count int - err := q.QueryRowContext(ctx, `SELECT COUNT(*) FROM edges`).Scan(&count) - return count, err -} - -// CountEdgesBatch returns inbound/outbound edge counts for multiple nodes in a single query. -func (s *Store) CountEdgesBatch(ctx context.Context, nodeIDs []string) (map[string][2]int, error) { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - return countEdgesBatchQ(ctx, s.q(), nodeIDs) -} - -func countEdgesBatchQ(ctx context.Context, q queryable, nodeIDs []string) (map[string][2]int, error) { - if len(nodeIDs) == 0 { - return nil, nil - } - result := make(map[string][2]int, len(nodeIDs)) - for _, id := range nodeIDs { - result[id] = [2]int{0, 0} - } - - // Count outbound edges (from_id IN (...)) - outQ := `SELECT from_id, COUNT(*) FROM edges WHERE from_id IN (` + placeholders(len(nodeIDs)) + `) GROUP BY from_id` - args := make([]any, len(nodeIDs)) - for i, id := range nodeIDs { - args[i] = id - } - rows, err := q.QueryContext(ctx, outQ, args...) - if err != nil { - return nil, fmt.Errorf("count edges batch outbound: %w", err) - } - defer func() { _ = rows.Close() }() - for rows.Next() { - var id string - var count int - if err := rows.Scan(&id, &count); err != nil { - return nil, err - } - if v, ok := result[id]; ok { - v[1] = count - result[id] = v - } - } - if err := rows.Err(); err != nil { - return nil, err - } - - // Count inbound edges (to_id IN (...)) - inQ := `SELECT to_id, COUNT(*) FROM edges WHERE to_id IN (` + placeholders(len(nodeIDs)) + `) GROUP BY to_id` - rows2, err := q.QueryContext(ctx, inQ, args...) - if err != nil { - return nil, fmt.Errorf("count edges batch inbound: %w", err) - } - defer func() { _ = rows2.Close() }() - for rows2.Next() { - var id string - var count int - if err := rows2.Scan(&id, &count); err != nil { - return nil, err - } - if v, ok := result[id]; ok { - v[0] = count - result[id] = v - } - } - if err := rows2.Err(); err != nil { - return nil, err - } - - return result, nil -} - -func placeholders(n int) string { - if n <= 0 { - return "" - } - b := make([]byte, 0, n*2-1) - for i := 0; i < n; i++ { - if i > 0 { - b = append(b, ',') - } - b = append(b, '?') - } - return string(b) -} - -// NodeStats returns nodes grouped by type and total count. -func (s *Store) NodeStats(ctx context.Context) (map[string]int, int, error) { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - return nodeStatsQ(ctx, s.db) -} - -func nodeStatsQ(ctx context.Context, q queryable) (map[string]int, int, error) { - rows, err := q.QueryContext(ctx, `SELECT type, COUNT(*) FROM nodes GROUP BY type`) - if err != nil { - return nil, 0, err - } - defer func() { _ = rows.Close() }() - stats := make(map[string]int) - total := 0 - for rows.Next() { - var typ string - var cnt int - if err := rows.Scan(&typ, &cnt); err != nil { - return nil, 0, err - } - stats[typ] = cnt - total += cnt - } - return stats, total, rows.Err() -} - -// DoctorStats returns aggregated diagnostic counts via SQL without loading -// individual nodes into memory. -func (s *Store) DoctorStats(ctx context.Context) (DoctorStatsResult, error) { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - return doctorStatsQ(ctx, s.db) -} - -func doctorStatsQ(ctx context.Context, q queryable) (DoctorStatsResult, error) { - var r DoctorStatsResult - - err := q.QueryRowContext(ctx, ` - SELECT - COUNT(*), - SUM(CASE WHEN confidence < 0.2 THEN 1 ELSE 0 END), - SUM(CASE WHEN pinned = 1 THEN 1 ELSE 0 END), - SUM(CASE WHEN NOT EXISTS ( - SELECT 1 FROM edges ef WHERE ef.from_id = nodes.id - UNION ALL - SELECT 1 FROM edges et WHERE et.to_id = nodes.id - LIMIT 1 - ) THEN 1 ELSE 0 END) - FROM nodes - `).Scan(&r.TotalNodes, &r.LowConfidence, &r.Pinned, &r.Orphans) - if err != nil { - return DoctorStatsResult{}, err - } - return r, nil -} - -// LastUpdated returns the most recent updated_at time across all nodes. -func (s *Store) LastUpdated(ctx context.Context) (time.Time, error) { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - return lastUpdatedQ(ctx, s.db) -} - -func lastUpdatedQ(ctx context.Context, q queryable) (time.Time, error) { - // updated_at is stored as a text timestamp; the sqlite driver returns - // MAX(updated_at) as a string, so scan into a NullString and parse it the - // same way the rest of this file does (RFC3339Nano with a legacy fallback). - // Scanning directly into sql.NullTime fails the driver's type conversion - // whenever a row exists. - var s sql.NullString - if err := q.QueryRowContext(ctx, `SELECT MAX(updated_at) FROM nodes`).Scan(&s); err != nil { - return time.Time{}, err - } - if !s.Valid || s.String == "" { - return time.Time{}, nil - } - // The sqlite driver stores a time.Time as its String() form - // ("2006-01-02 15:04:05.999999999 -0700 MST"); also accept RFC3339 and the - // legacy "2006-01-02 15:04:05" layout for robustness across writers. - for _, layout := range []string{"2006-01-02 15:04:05.999999999 -0700 MST", time.RFC3339Nano, "2006-01-02 15:04:05"} { - if t, err := time.Parse(layout, s.String); err == nil { - return t, nil - } - } - return time.Time{}, nil -} - -// TopConnected returns the content of the most-connected nodes by edge count. -func (s *Store) TopConnected(ctx context.Context, limit int) ([]string, error) { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - return topConnectedQ(ctx, s.q(), limit) -} - -func topConnectedQ(ctx context.Context, q queryable, limit int) ([]string, error) { - rows, err := q.QueryContext(ctx, `SELECT n.content, COUNT(*) as cnt - FROM edges e JOIN nodes n ON n.id = e.from_id OR n.id = e.to_id - GROUP BY n.id ORDER BY cnt DESC LIMIT ?`, limit) - if err != nil { - return nil, err - } - defer func() { _ = rows.Close() }() - var out []string - for rows.Next() { - var content string - var cnt int - if err := rows.Scan(&content, &cnt); err != nil { - return nil, err - } - out = append(out, content) - } - return out, rows.Err() -} - -// CheckCycle uses a recursive CTE to detect if adding from->to would create a cycle -// among acyclic edges. -func (s *Store) CheckCycle(ctx context.Context, fromID, toID string) (bool, error) { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - var result bool - err := retryOnBusy(func() error { - var innerErr error - result, innerErr = checkCycleQ(ctx, s.q(), fromID, toID) - return innerErr - }, 5, 2*time.Millisecond) - return result, err -} - -func checkCycleQ(ctx context.Context, q queryable, fromID, toID string) (bool, error) { - // UNION (not UNION ALL) deduplicates the frontier: if a cycle ever slips - // into the acyclic edge set, UNION ALL would recurse forever, whereas UNION - // terminates once every reachable ancestor has been visited. - query := ` - WITH RECURSIVE ancestors(id) AS ( - SELECT ? - UNION - SELECT e.from_id FROM ancestors a - JOIN edges e ON e.to_id = a.id AND e.acyclic = 1 - ) - SELECT 1 FROM ancestors WHERE id = ? LIMIT 1` - var exists int - err := q.QueryRowContext(ctx, query, fromID, toID).Scan(&exists) - if err == nil { - return true, nil - } - if err == sql.ErrNoRows { - return false, nil - } - return false, err -} - -// WithTx runs the given function inside a SQL transaction. -// If the function returns an error, the transaction is rolled back. -// -// The whole transaction is retried on SQLITE_BUSY. A transaction acquires the -// single SQLite writer more eagerly than a lone statement, so it can lose a -// race with the async ingestion goroutine; without retry, correctness-critical -// transactional writes (e.g. the atomic cycle-check+insert in AddEdge) would -// fail spuriously under concurrency. Re-running fn is safe because a busy error -// occurs before commit, so no partial state is visible. -func (s *Store) WithTx(ctx context.Context, fn func(Storage) error) error { - return retryOnBusy(func() error { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return err - } - defer func() { _ = tx.Rollback() }() - - txStore := &txStore{tx: tx} - if err := fn(txStore); err != nil { - return err - } - return tx.Commit() - }, 5, 50*time.Millisecond) -} - -// txStore is a Storage implementation backed by a SQL transaction. -type txStore struct { - tx *sql.Tx -} - -// txStore is a thin wrapper that delegates all operations to shared *Q functions. -func (t *txStore) CreateNode(ctx context.Context, n *Node) error { return createNodeQ(ctx, t.tx, n) } - -func (t *txStore) GetNode(ctx context.Context, id string) (*Node, error) { - return getNodeQ(ctx, t.tx, id) -} - -func (t *txStore) GetNodeByKey(ctx context.Context, key, project string) (*Node, error) { - return getNodeByKeyQ(ctx, t.tx, key, project) -} - -func (t *txStore) GetNodesBatch(ctx context.Context, ids []string) ([]*Node, error) { - return getNodesBatchQ(ctx, t.tx, ids) -} - -func (t *txStore) UpdateNode(ctx context.Context, n *Node) error { return updateNodeQ(ctx, t.tx, n) } - -func (t *txStore) UpdateNodeContent(ctx context.Context, id, newContent string) error { - return updateNodeContentQ(ctx, t.tx, id, newContent) -} - -func (t *txStore) DeleteNode(ctx context.Context, id string) error { return deleteNodeQ(ctx, t.tx, id) } - -func (t *txStore) ListNodes(ctx context.Context, f NodeFilter) ([]*Node, error) { - return listNodesQ(ctx, t.tx, f) -} - -func (t *txStore) SearchNodes(ctx context.Context, query string, limit int) ([]*Node, error) { - return searchNodesQ(ctx, t.tx, query, limit) -} - -func (t *txStore) SearchNodeByHash(ctx context.Context, hash, scope, project string) (*Node, error) { - return searchNodeByHashQ(ctx, t.tx, hash, scope, project) -} - -func (t *txStore) GetNeighbors(ctx context.Context, nodeID string) ([]*Node, error) { - return getNeighborsQ(ctx, t.tx, nodeID) -} - -func (t *txStore) CreateEdge(ctx context.Context, e *Edge) error { return createEdgeQ(ctx, t.tx, e) } - -func (t *txStore) GetEdge(ctx context.Context, id string) (*Edge, error) { - return getEdgeQ(ctx, t.tx, id) -} - -func (t *txStore) InvalidateEdge(ctx context.Context, id string) error { - return invalidateEdgeQ(ctx, t.tx, id) -} - -func (t *txStore) DeleteEdge(ctx context.Context, id string) error { return deleteEdgeQ(ctx, t.tx, id) } - -func (t *txStore) GetEdgesFrom(ctx context.Context, nodeID string) ([]*Edge, error) { - return queryEdgesQ(ctx, t.tx, `SELECT id, from_id, to_id, type, acyclic, weight, metadata, valid_at, invalid_at, created_at FROM edges WHERE from_id=? AND invalid_at IS NULL`, nodeID) -} - -func (t *txStore) GetEdgesTo(ctx context.Context, nodeID string) ([]*Edge, error) { - return queryEdgesQ(ctx, t.tx, `SELECT id, from_id, to_id, type, acyclic, weight, metadata, valid_at, invalid_at, created_at FROM edges WHERE to_id=? AND invalid_at IS NULL`, nodeID) -} - -func (t *txStore) GetValidEdgesFrom(ctx context.Context, nodeID string, at time.Time) ([]*Edge, error) { - at = validityInstant(at) - return queryEdgesQ(ctx, t.tx, `SELECT id, from_id, to_id, type, acyclic, weight, metadata, valid_at, invalid_at, created_at FROM edges - WHERE from_id=? AND (valid_at IS NULL OR valid_at <= ?) AND (invalid_at IS NULL OR invalid_at > ?)`, nodeID, at, at) -} - -func (t *txStore) GetValidEdgesTo(ctx context.Context, nodeID string, at time.Time) ([]*Edge, error) { - at = validityInstant(at) - return queryEdgesQ(ctx, t.tx, `SELECT id, from_id, to_id, type, acyclic, weight, metadata, valid_at, invalid_at, created_at FROM edges - WHERE to_id=? AND (valid_at IS NULL OR valid_at <= ?) AND (invalid_at IS NULL OR invalid_at > ?)`, nodeID, at, at) -} - -func (t *txStore) GetEdgesBetween(ctx context.Context, nodeIDs []string) ([]*Edge, error) { - return getEdgesBetweenQ(ctx, t.tx, nodeIDs) -} - -func (t *txStore) GetAllEdgesFor(ctx context.Context, nodeIDs []string) ([]*Edge, error) { - return getAllEdgesForQ(ctx, t.tx, nodeIDs) -} - -func (t *txStore) CountEdges(ctx context.Context, nodeID string) (int, int, error) { - return countEdgesQ(ctx, t.tx, nodeID) -} - -func (t *txStore) CountAllEdges(ctx context.Context) (int, error) { return countAllEdgesQ(ctx, t.tx) } - -func (t *txStore) CountEdgesBatch(ctx context.Context, nodeIDs []string) (map[string][2]int, error) { - return countEdgesBatchQ(ctx, t.tx, nodeIDs) -} - -func (t *txStore) NodeStats(ctx context.Context) (map[string]int, int, error) { - return nodeStatsQ(ctx, t.tx) -} - -func (t *txStore) DoctorStats(ctx context.Context) (DoctorStatsResult, error) { - return doctorStatsQ(ctx, t.tx) -} - -func (t *txStore) LastUpdated(ctx context.Context) (time.Time, error) { return lastUpdatedQ(ctx, t.tx) } - -func (t *txStore) TopConnected(ctx context.Context, limit int) ([]string, error) { - return topConnectedQ(ctx, t.tx, limit) -} - -func (t *txStore) CheckCycle(ctx context.Context, fromID, toID string) (bool, error) { - return checkCycleQ(ctx, t.tx, fromID, toID) -} - -func (t *txStore) CreateSession(ctx context.Context, sess *Session) error { - return createSessionQ(ctx, t.tx, sess) -} - -func (t *txStore) EndSession(ctx context.Context, id string, summary string) error { - return endSessionQ(ctx, t.tx, id, summary) -} - -func (t *txStore) ListSessions(ctx context.Context, project string, limit int) ([]*Session, error) { - return listSessionsQ(ctx, t.tx, project, limit) -} - -func (t *txStore) SaveVersion(ctx context.Context, nodeID string, content, changedBy, reason string) error { - return saveVersionQ(ctx, t.tx, nodeID, content, changedBy, reason) -} - -func (t *txStore) GetVersions(ctx context.Context, nodeID string) ([]*NodeVersion, error) { - return getVersionsQ(ctx, t.tx, nodeID) -} - -func (t *txStore) SaveEmbedding(ctx context.Context, nodeID, model string, vector []float32) error { - return saveEmbeddingQ(ctx, t.tx, nodeID, model, vector) -} - -func (t *txStore) DeleteEmbedding(ctx context.Context, nodeID string) error { - return deleteEmbeddingQ(ctx, t.tx, nodeID) -} - -func (t *txStore) GetEmbedding(ctx context.Context, nodeID string) ([]float32, string, error) { - return getEmbeddingQ(ctx, t.tx, nodeID) -} - -func (t *txStore) AllEmbeddings(ctx context.Context, model string) (map[string][]float32, error) { - return allEmbeddingsQ(ctx, t.tx, model) -} - -func (t *txStore) GetEmbeddingsBatch(ctx context.Context, model string, offset, limit int) (map[string][]float32, error) { - return getEmbeddingsBatchQ(ctx, t.tx, model, offset, limit) -} - -func (t *txStore) AddFileWatch(ctx context.Context, filePath, nodeID, gitHash string) error { - return addFileWatchQ(ctx, t.tx, filePath, nodeID, gitHash) -} - -func (t *txStore) AddReplayEvent(ctx context.Context, sessionID, data string) error { - return addReplayEventQ(ctx, t.tx, sessionID, data) -} - -func (t *txStore) GetReplayEvents(ctx context.Context, sessionID string) ([]*ReplayEvent, error) { - return getReplayEventsQ(ctx, t.tx, sessionID) -} - -func (t *txStore) LogAccess(ctx context.Context, nodeID string) error { - return logAccessQ(ctx, t.tx, nodeID) -} - -func (t *txStore) SaveNodeMetadata(ctx context.Context, nodeID string, meta map[string]string) error { - return saveNodeMetadataQ(ctx, t.tx, nodeID, meta) -} - -func (t *txStore) LoadNodeMetadata(ctx context.Context, nodeIDs []string) (map[string]map[string]string, error) { - ids, args := chunkedArgs(nodeIDs) - return loadNodeMetadataQ(ctx, t.tx, ids, args) -} - -func (t *txStore) SaveSignature(ctx context.Context, nodeID, signature string) error { - return saveSignatureQ(ctx, t.tx, nodeID, signature) -} - -func (t *txStore) GetAllSignatures(ctx context.Context) (map[string]string, error) { - return getAllSignaturesQ(ctx, t.tx) -} - -func (t *txStore) FlushAccessLog(ctx context.Context) (int, error) { return flushAccessLogQ(ctx, t.tx) } -func (t *txStore) WithTx(ctx context.Context, fn func(Storage) error) error { return fn(t) } -func (t *txStore) Close() error { return nil } - -// RollbackToVersion restores a node's content to a specific version. -// All operations run inside a single transaction to prevent concurrent -// modifications from interleaving between the SELECT and UPDATE. -func (s *Store) RollbackToVersion(ctx context.Context, nodeID string, version int) error { - return s.WithTx(ctx, func(tx Storage) error { - // GetVersions is not on Storage interface, so use GetNode + version query via txStore. - // We need the raw version content. Use the versions listing through txStore. - var content string - // txStore implements Storage; cast to access version query - if tts, ok := tx.(*txStore); ok { - err := tts.tx.QueryRowContext(ctx, - `SELECT content FROM node_versions WHERE node_id=? AND version=?`, nodeID, version).Scan(&content) - if err != nil { - return fmt.Errorf("version %d not found for node %s: %w", version, nodeID, err) - } - } else { - return fmt.Errorf("unexpected storage type in transaction") - } - if err := tx.UpdateNodeContent(ctx, nodeID, content); err != nil { - return err - } - return tx.SaveVersion(ctx, nodeID, content, "system", fmt.Sprintf("rollback to v%d", version)) - }) -} - -// DiffVersions returns the content of two versions for comparison. -func (s *Store) DiffVersions(ctx context.Context, nodeID string, v1, v2 int) (content1, content2 string, err error) { - ctx, cancel := s.withTimeout(ctx) - defer cancel() - err = s.db.QueryRowContext(ctx, - `SELECT content FROM node_versions WHERE node_id=? AND version=?`, nodeID, v1).Scan(&content1) - if err != nil { - return "", "", fmt.Errorf("version %d not found: %w", v1, err) - } - err = s.db.QueryRowContext(ctx, - `SELECT content FROM node_versions WHERE node_id=? AND version=?`, nodeID, v2).Scan(&content2) - if err != nil { - return "", "", fmt.Errorf("version %d not found: %w", v2, err) - } - return content1, content2, nil -} diff --git a/storage/sqlite_edges.go b/storage/sqlite_edges.go new file mode 100644 index 0000000..3859ef6 --- /dev/null +++ b/storage/sqlite_edges.go @@ -0,0 +1,518 @@ +// This file is part of package storage. It holds edge, graph-statistics, +// and cycle-check storage operations split verbatim out of sqlite.go for +// readability; behavior is unchanged. + +package storage + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + "time" +) + +// --- Edges --- + +func (s *Store) CreateEdge(ctx context.Context, e *Edge) error { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + // Retry on SQLITE_BUSY: SelfLink runs outside the engine's write lock and + // competes for the SQLite write lock with concurrent graph operations. + return retryOnBusy(func() error { + return createEdgeQ(ctx, s.q(), e) + }, 5, 2*time.Millisecond) +} + +func createEdgeQ(ctx context.Context, q queryable, e *Edge) error { + // Temporal validity: an edge becomes valid the moment it is created unless + // the caller explicitly supplied a valid_at. This keeps valid_at "live" for + // every insert path without requiring callers to set it. CreatedAt is also + // defaulted to now so the columns are never NULL for new rows. + now := time.Now().UTC() + if e.ValidAt.IsZero() { + e.ValidAt = now + } + if e.CreatedAt.IsZero() { + e.CreatedAt = now + } + _, err := q.ExecContext(ctx, `INSERT INTO edges (id, from_id, to_id, type, acyclic, weight, metadata, valid_at, invalid_at, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + e.ID, e.FromID, e.ToID, e.Type, e.Acyclic, e.Weight, e.Metadata, nullTime(e.ValidAt), nullTime(e.InvalidAt), e.CreatedAt) + if err != nil && strings.Contains(err.Error(), "UNIQUE constraint failed") { + return fmt.Errorf("%w: %s", ErrDuplicateEdge, err) + } + return err +} + +// GetEdge retrieves an edge by its primary key ID. +// Returns ErrEdgeNotFound wrapped with the ID when no row is found. +func (s *Store) GetEdge(ctx context.Context, id string) (*Edge, error) { + return retryOnBusyVal(func() (*Edge, error) { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + return getEdgeQ(ctx, s.q(), id) + }, 5, 50*time.Millisecond) +} + +func getEdgeQ(ctx context.Context, q queryable, id string) (*Edge, error) { + e := &Edge{} + var validAt, invalidAt sql.NullTime + err := q.QueryRowContext(ctx, `SELECT id, from_id, to_id, type, acyclic, weight, metadata, valid_at, invalid_at, created_at FROM edges WHERE id=?`, id). + Scan(&e.ID, &e.FromID, &e.ToID, &e.Type, &e.Acyclic, &e.Weight, &e.Metadata, &validAt, &invalidAt, &e.CreatedAt) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, fmt.Errorf("%w: %s", ErrEdgeNotFound, id) + } + return nil, err + } + if validAt.Valid { + e.ValidAt = validAt.Time + } + if invalidAt.Valid { + e.InvalidAt = invalidAt.Time + } + return e, nil +} + +func (s *Store) InvalidateEdge(ctx context.Context, id string) error { + return retryOnBusy(func() error { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + return invalidateEdgeQ(ctx, s.q(), id) + }, 5, 50*time.Millisecond) +} + +func invalidateEdgeQ(ctx context.Context, q queryable, id string) error { + _, err := q.ExecContext(ctx, `UPDATE edges SET invalid_at = ? WHERE id = ?`, time.Now().UTC(), id) + return err +} + +func (s *Store) DeleteEdge(ctx context.Context, id string) error { + return retryOnBusy(func() error { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + return deleteEdgeQ(ctx, s.q(), id) + }, 5, 50*time.Millisecond) +} + +func deleteEdgeQ(ctx context.Context, q queryable, id string) error { + _, err := q.ExecContext(ctx, `DELETE FROM edges WHERE id=?`, id) + return err +} + +func (s *Store) GetEdgesFrom(ctx context.Context, nodeID string) ([]*Edge, error) { + return retryOnBusyVal(func() ([]*Edge, error) { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + return queryEdgesQ(ctx, s.q(), `SELECT id, from_id, to_id, type, acyclic, weight, metadata, valid_at, invalid_at, created_at FROM edges WHERE from_id=? AND invalid_at IS NULL`, nodeID) + }, 5, 50*time.Millisecond) +} + +func (s *Store) GetEdgesTo(ctx context.Context, nodeID string) ([]*Edge, error) { + return retryOnBusyVal(func() ([]*Edge, error) { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + return queryEdgesQ(ctx, s.q(), `SELECT id, from_id, to_id, type, acyclic, weight, metadata, valid_at, invalid_at, created_at FROM edges WHERE to_id=? AND invalid_at IS NULL`, nodeID) + }, 5, 50*time.Millisecond) +} + +// GetValidEdgesFrom returns outbound edges from nodeID that were valid at the +// given instant: valid_at <= at AND (invalid_at IS NULL OR invalid_at > at). +// A zero `at` defaults to now, giving a "currently valid" view. This is a +// point-in-time variant of GetEdgesFrom; the latter is left unchanged so +// existing callers keep their current behavior. +func (s *Store) GetValidEdgesFrom(ctx context.Context, nodeID string, at time.Time) ([]*Edge, error) { + at = validityInstant(at) + return retryOnBusyVal(func() ([]*Edge, error) { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + return queryEdgesQ(ctx, s.q(), `SELECT id, from_id, to_id, type, acyclic, weight, metadata, valid_at, invalid_at, created_at FROM edges + WHERE from_id=? AND (valid_at IS NULL OR valid_at <= ?) AND (invalid_at IS NULL OR invalid_at > ?)`, nodeID, at, at) + }, 5, 50*time.Millisecond) +} + +// GetValidEdgesTo is the inbound counterpart of GetValidEdgesFrom. +func (s *Store) GetValidEdgesTo(ctx context.Context, nodeID string, at time.Time) ([]*Edge, error) { + at = validityInstant(at) + return retryOnBusyVal(func() ([]*Edge, error) { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + return queryEdgesQ(ctx, s.q(), `SELECT id, from_id, to_id, type, acyclic, weight, metadata, valid_at, invalid_at, created_at FROM edges + WHERE to_id=? AND (valid_at IS NULL OR valid_at <= ?) AND (invalid_at IS NULL OR invalid_at > ?)`, nodeID, at, at) + }, 5, 50*time.Millisecond) +} + +// validityInstant normalizes a point-in-time argument: a zero time means "now". +func validityInstant(at time.Time) time.Time { + if at.IsZero() { + return time.Now().UTC() + } + return at.UTC() +} + +func queryEdgesQ(ctx context.Context, q queryable, query string, args ...any) ([]*Edge, error) { + rows, err := q.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + return scanEdges(rows) +} + +func (s *Store) GetEdgesBetween(ctx context.Context, nodeIDs []string) ([]*Edge, error) { + return retryOnBusyVal(func() ([]*Edge, error) { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + return getEdgesBetweenQ(ctx, s.q(), nodeIDs) + }, 5, 50*time.Millisecond) +} + +func getEdgesBetweenQ(ctx context.Context, q queryable, nodeIDs []string) ([]*Edge, error) { + if len(nodeIDs) == 0 { + return nil, nil + } + var all []*Edge + for i := 0; i < len(nodeIDs); i += maxSQLVariables { + end := i + maxSQLVariables + if end > len(nodeIDs) { + end = len(nodeIDs) + } + chunk := nodeIDs[i:end] + placeholders := make([]string, len(chunk)) + args := make([]any, 0, len(chunk)*2) + for j, id := range chunk { + placeholders[j] = "?" + args = append(args, id) + } + ph := strings.Join(placeholders, ",") + query := fmt.Sprintf(`SELECT id, from_id, to_id, type, acyclic, weight, metadata, valid_at, invalid_at, created_at FROM edges + WHERE from_id IN (%s) AND to_id IN (%s)`, ph, ph) + args = append(args, args...) + edges, err := queryEdgesQ(ctx, q, query, args...) + if err != nil { + return nil, err + } + all = append(all, edges...) + } + return all, nil +} + +// GetAllEdgesFor returns all edges where from_id OR to_id is in the given set. +// Used by IntentBFS for batch edge retrieval (avoids N+1 queries). +func (s *Store) GetAllEdgesFor(ctx context.Context, nodeIDs []string) ([]*Edge, error) { + return retryOnBusyVal(func() ([]*Edge, error) { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + return getAllEdgesForQ(ctx, s.q(), nodeIDs) + }, 5, 50*time.Millisecond) +} + +func getAllEdgesForQ(ctx context.Context, q queryable, nodeIDs []string) ([]*Edge, error) { + if len(nodeIDs) == 0 { + return nil, nil + } + var all []*Edge + for i := 0; i < len(nodeIDs); i += maxSQLVariables { + end := i + maxSQLVariables + if end > len(nodeIDs) { + end = len(nodeIDs) + } + chunk := nodeIDs[i:end] + placeholders := make([]string, len(chunk)) + args := make([]any, 0, len(chunk)*2) + for j, id := range chunk { + placeholders[j] = "?" + args = append(args, id) + } + ph := strings.Join(placeholders, ",") + query := fmt.Sprintf(`SELECT id, from_id, to_id, type, acyclic, weight, metadata, valid_at, invalid_at, created_at FROM edges + WHERE from_id IN (%s) OR to_id IN (%s)`, ph, ph) + args = append(args, args[:len(chunk)]...) // second set of args for to_id + edges, err := queryEdgesQ(ctx, q, query, args...) + if err != nil { + return nil, err + } + all = append(all, edges...) + } + return all, nil +} + +func (s *Store) GetNeighbors(ctx context.Context, nodeID string) ([]*Node, error) { + return retryOnBusyVal(func() ([]*Node, error) { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + return getNeighborsQ(ctx, s.q(), nodeID) + }, 5, 50*time.Millisecond) +} + +func getNeighborsQ(ctx context.Context, q queryable, nodeID string) ([]*Node, error) { + rows, err := q.QueryContext(ctx, `SELECT DISTINCT n.id, n.type, n.content, n.content_hash, n.summary, n.scope, n.project, n.tier, n.tags, n.key, n.pinned, n.confidence, n.access_count, n.created_at, n.updated_at, n.accessed_at, n.source_session, n.source_agent, n.version + FROM nodes n JOIN edges e ON (e.to_id = n.id AND e.from_id = ?) OR (e.from_id = n.id AND e.to_id = ?)`, nodeID, nodeID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + return scanNodes(rows) +} + +// CountEdges returns inbound and outbound edge counts for a node. +func (s *Store) CountEdges(ctx context.Context, nodeID string) (inbound int, outbound int, err error) { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + return countEdgesQ(ctx, s.q(), nodeID) +} + +func countEdgesQ(ctx context.Context, q queryable, nodeID string) (int, int, error) { + var inbound, outbound int + err := q.QueryRowContext(ctx, + `SELECT (SELECT COUNT(*) FROM edges WHERE to_id = ?), (SELECT COUNT(*) FROM edges WHERE from_id = ?)`, + nodeID, nodeID).Scan(&inbound, &outbound) + return inbound, outbound, err +} + +// CountAllEdges returns the total number of edges in the graph. +func (s *Store) CountAllEdges(ctx context.Context) (int, error) { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + return countAllEdgesQ(ctx, s.db) +} + +func countAllEdgesQ(ctx context.Context, q queryable) (int, error) { + var count int + err := q.QueryRowContext(ctx, `SELECT COUNT(*) FROM edges`).Scan(&count) + return count, err +} + +// CountEdgesBatch returns inbound/outbound edge counts for multiple nodes in a single query. +func (s *Store) CountEdgesBatch(ctx context.Context, nodeIDs []string) (map[string][2]int, error) { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + return countEdgesBatchQ(ctx, s.q(), nodeIDs) +} + +func countEdgesBatchQ(ctx context.Context, q queryable, nodeIDs []string) (map[string][2]int, error) { + if len(nodeIDs) == 0 { + return nil, nil + } + result := make(map[string][2]int, len(nodeIDs)) + for _, id := range nodeIDs { + result[id] = [2]int{0, 0} + } + + // Count outbound edges (from_id IN (...)) + outQ := `SELECT from_id, COUNT(*) FROM edges WHERE from_id IN (` + placeholders(len(nodeIDs)) + `) GROUP BY from_id` + args := make([]any, len(nodeIDs)) + for i, id := range nodeIDs { + args[i] = id + } + rows, err := q.QueryContext(ctx, outQ, args...) + if err != nil { + return nil, fmt.Errorf("count edges batch outbound: %w", err) + } + defer func() { _ = rows.Close() }() + for rows.Next() { + var id string + var count int + if err := rows.Scan(&id, &count); err != nil { + return nil, err + } + if v, ok := result[id]; ok { + v[1] = count + result[id] = v + } + } + if err := rows.Err(); err != nil { + return nil, err + } + + // Count inbound edges (to_id IN (...)) + inQ := `SELECT to_id, COUNT(*) FROM edges WHERE to_id IN (` + placeholders(len(nodeIDs)) + `) GROUP BY to_id` + rows2, err := q.QueryContext(ctx, inQ, args...) + if err != nil { + return nil, fmt.Errorf("count edges batch inbound: %w", err) + } + defer func() { _ = rows2.Close() }() + for rows2.Next() { + var id string + var count int + if err := rows2.Scan(&id, &count); err != nil { + return nil, err + } + if v, ok := result[id]; ok { + v[0] = count + result[id] = v + } + } + if err := rows2.Err(); err != nil { + return nil, err + } + + return result, nil +} + +func placeholders(n int) string { + if n <= 0 { + return "" + } + b := make([]byte, 0, n*2-1) + for i := 0; i < n; i++ { + if i > 0 { + b = append(b, ',') + } + b = append(b, '?') + } + return string(b) +} + +// NodeStats returns nodes grouped by type and total count. +func (s *Store) NodeStats(ctx context.Context) (map[string]int, int, error) { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + return nodeStatsQ(ctx, s.db) +} + +func nodeStatsQ(ctx context.Context, q queryable) (map[string]int, int, error) { + rows, err := q.QueryContext(ctx, `SELECT type, COUNT(*) FROM nodes GROUP BY type`) + if err != nil { + return nil, 0, err + } + defer func() { _ = rows.Close() }() + stats := make(map[string]int) + total := 0 + for rows.Next() { + var typ string + var cnt int + if err := rows.Scan(&typ, &cnt); err != nil { + return nil, 0, err + } + stats[typ] = cnt + total += cnt + } + return stats, total, rows.Err() +} + +// DoctorStats returns aggregated diagnostic counts via SQL without loading +// individual nodes into memory. +func (s *Store) DoctorStats(ctx context.Context) (DoctorStatsResult, error) { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + return doctorStatsQ(ctx, s.db) +} + +func doctorStatsQ(ctx context.Context, q queryable) (DoctorStatsResult, error) { + var r DoctorStatsResult + + err := q.QueryRowContext(ctx, ` + SELECT + COUNT(*), + SUM(CASE WHEN confidence < 0.2 THEN 1 ELSE 0 END), + SUM(CASE WHEN pinned = 1 THEN 1 ELSE 0 END), + SUM(CASE WHEN NOT EXISTS ( + SELECT 1 FROM edges ef WHERE ef.from_id = nodes.id + UNION ALL + SELECT 1 FROM edges et WHERE et.to_id = nodes.id + LIMIT 1 + ) THEN 1 ELSE 0 END) + FROM nodes + `).Scan(&r.TotalNodes, &r.LowConfidence, &r.Pinned, &r.Orphans) + if err != nil { + return DoctorStatsResult{}, err + } + return r, nil +} + +// LastUpdated returns the most recent updated_at time across all nodes. +func (s *Store) LastUpdated(ctx context.Context) (time.Time, error) { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + return lastUpdatedQ(ctx, s.db) +} + +func lastUpdatedQ(ctx context.Context, q queryable) (time.Time, error) { + // updated_at is stored as a text timestamp; the sqlite driver returns + // MAX(updated_at) as a string, so scan into a NullString and parse it the + // same way the rest of this file does (RFC3339Nano with a legacy fallback). + // Scanning directly into sql.NullTime fails the driver's type conversion + // whenever a row exists. + var s sql.NullString + if err := q.QueryRowContext(ctx, `SELECT MAX(updated_at) FROM nodes`).Scan(&s); err != nil { + return time.Time{}, err + } + if !s.Valid || s.String == "" { + return time.Time{}, nil + } + // The sqlite driver stores a time.Time as its String() form + // ("2006-01-02 15:04:05.999999999 -0700 MST"); also accept RFC3339 and the + // legacy "2006-01-02 15:04:05" layout for robustness across writers. + for _, layout := range []string{"2006-01-02 15:04:05.999999999 -0700 MST", time.RFC3339Nano, "2006-01-02 15:04:05"} { + if t, err := time.Parse(layout, s.String); err == nil { + return t, nil + } + } + return time.Time{}, nil +} + +// TopConnected returns the content of the most-connected nodes by edge count. +func (s *Store) TopConnected(ctx context.Context, limit int) ([]string, error) { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + return topConnectedQ(ctx, s.q(), limit) +} + +func topConnectedQ(ctx context.Context, q queryable, limit int) ([]string, error) { + rows, err := q.QueryContext(ctx, `SELECT n.content, COUNT(*) as cnt + FROM edges e JOIN nodes n ON n.id = e.from_id OR n.id = e.to_id + GROUP BY n.id ORDER BY cnt DESC LIMIT ?`, limit) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + var out []string + for rows.Next() { + var content string + var cnt int + if err := rows.Scan(&content, &cnt); err != nil { + return nil, err + } + out = append(out, content) + } + return out, rows.Err() +} + +// CheckCycle uses a recursive CTE to detect if adding from->to would create a cycle +// among acyclic edges. +func (s *Store) CheckCycle(ctx context.Context, fromID, toID string) (bool, error) { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + var result bool + err := retryOnBusy(func() error { + var innerErr error + result, innerErr = checkCycleQ(ctx, s.q(), fromID, toID) + return innerErr + }, 5, 2*time.Millisecond) + return result, err +} + +func checkCycleQ(ctx context.Context, q queryable, fromID, toID string) (bool, error) { + // UNION (not UNION ALL) deduplicates the frontier: if a cycle ever slips + // into the acyclic edge set, UNION ALL would recurse forever, whereas UNION + // terminates once every reachable ancestor has been visited. + query := ` + WITH RECURSIVE ancestors(id) AS ( + SELECT ? + UNION + SELECT e.from_id FROM ancestors a + JOIN edges e ON e.to_id = a.id AND e.acyclic = 1 + ) + SELECT 1 FROM ancestors WHERE id = ? LIMIT 1` + var exists int + err := q.QueryRowContext(ctx, query, fromID, toID).Scan(&exists) + if err == nil { + return true, nil + } + if err == sql.ErrNoRows { + return false, nil + } + return false, err +} diff --git a/storage/sqlite_nodes.go b/storage/sqlite_nodes.go new file mode 100644 index 0000000..10d9327 --- /dev/null +++ b/storage/sqlite_nodes.go @@ -0,0 +1,608 @@ +// This file is part of package storage. It holds node, version, metadata, +// signature, and access-log storage operations split verbatim out of +// sqlite.go for readability; behavior is unchanged. + +package storage + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + "time" + + "github.com/GrayCodeAI/yaad/internal/telemetry" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" +) + +// --- Nodes --- + +func (s *Store) CreateNode(ctx context.Context, n *Node) error { + start := time.Now() + ctx, cancel := s.withTimeout(ctx) + defer cancel() + err := retryOnBusy(func() error { + return createNodeQ(ctx, s.q(), n) + }, 5, 2*time.Millisecond) + attrs := attribute.NewSet(attribute.String("op", "create_node")) + telemetry.SQLiteQueryDuration.Record(ctx, time.Since(start).Seconds(), metric.WithAttributeSet(attrs)) + telemetry.SQLiteQueryCount.Add(ctx, 1, metric.WithAttributeSet(attrs)) + return err +} + +func createNodeQ(ctx context.Context, q queryable, n *Node) error { + _, err := q.ExecContext(ctx, `INSERT INTO nodes (id, type, content, content_hash, summary, scope, project, tier, tags, key, pinned, confidence, access_count, created_at, updated_at, accessed_at, source_session, source_agent, version) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + n.ID, n.Type, n.Content, n.ContentHash, n.Summary, n.Scope, n.Project, n.Tier, n.Tags, nullString(n.Key), n.Pinned, n.Confidence, n.AccessCount, + n.CreatedAt, n.UpdatedAt, nullTime(n.AccessedAt), n.SourceSession, n.SourceAgent, n.Version) + if err != nil && strings.Contains(err.Error(), "UNIQUE constraint failed") { + return fmt.Errorf("%w: %s", ErrDuplicateNode, err) + } + if err != nil { + return err + } + // Persist structured metadata. + return saveNodeMetadataQ(ctx, q, n.ID, n.Metadata) +} + +// saveNodeMetadataQ replaces all metadata for a node, inserting rows for each +// key-value pair. Uses the same queryable (pooled connection or transaction). +func saveNodeMetadataQ(ctx context.Context, q queryable, nodeID string, meta map[string]string) error { + if len(meta) == 0 { + return nil + } + // Delete existing metadata for this node (upsert semantics). + if _, err := q.ExecContext(ctx, `DELETE FROM node_metadata WHERE node_id = ?`, nodeID); err != nil { + return err + } + for k, v := range meta { + if _, err := q.ExecContext(ctx, + `INSERT INTO node_metadata (node_id, key, value) VALUES (?, ?, ?)`, nodeID, k, v); err != nil { + return err + } + } + return nil +} + +// GetNode retrieves a node by its primary key ID. +// Returns ErrNodeNotFound wrapped with the ID when no row is found. +func (s *Store) GetNode(ctx context.Context, id string) (*Node, error) { + start := time.Now() + n, err := retryOnBusyVal(func() (*Node, error) { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + return getNodeQ(ctx, s.q(), id) + }, 5, 50*time.Millisecond) + attrs := attribute.NewSet(attribute.String("op", "get_node")) + telemetry.SQLiteQueryDuration.Record(ctx, time.Since(start).Seconds(), metric.WithAttributeSet(attrs)) + telemetry.SQLiteQueryCount.Add(ctx, 1, metric.WithAttributeSet(attrs)) + _ = err // metrics always recorded + return n, err +} + +func getNodeQ(ctx context.Context, q queryable, id string) (*Node, error) { + n := &Node{} + var accessedAt sql.NullTime + var key sql.NullString + err := q.QueryRowContext(ctx, `SELECT id, type, content, content_hash, summary, scope, project, tier, tags, key, pinned, confidence, access_count, created_at, updated_at, accessed_at, source_session, source_agent, version FROM nodes WHERE id = ?`, id). + Scan(&n.ID, &n.Type, &n.Content, &n.ContentHash, &n.Summary, &n.Scope, &n.Project, &n.Tier, &n.Tags, &key, &n.Pinned, &n.Confidence, &n.AccessCount, &n.CreatedAt, &n.UpdatedAt, &accessedAt, &n.SourceSession, &n.SourceAgent, &n.Version) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, fmt.Errorf("%w: %s", ErrNodeNotFound, id) + } + return nil, err + } + if accessedAt.Valid { + n.AccessedAt = accessedAt.Time + } + if key.Valid { + n.Key = key.String + } + return n, nil +} + +// GetNodeByKey looks up a node by its unique key within a project. +// Returns (nil, nil) when no matching node is found (upsert check pattern). +func (s *Store) GetNodeByKey(ctx context.Context, key, project string) (*Node, error) { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + return getNodeByKeyQ(ctx, s.q(), key, project) +} + +func getNodeByKeyQ(ctx context.Context, q queryable, key, project string) (*Node, error) { + n := &Node{} + var accessedAt sql.NullTime + var k sql.NullString + err := q.QueryRowContext(ctx, `SELECT id, type, content, content_hash, summary, scope, project, tier, tags, key, pinned, confidence, access_count, created_at, updated_at, accessed_at, source_session, source_agent, version FROM nodes WHERE key = ? AND project = ?`, key, project). + Scan(&n.ID, &n.Type, &n.Content, &n.ContentHash, &n.Summary, &n.Scope, &n.Project, &n.Tier, &n.Tags, &k, &n.Pinned, &n.Confidence, &n.AccessCount, &n.CreatedAt, &n.UpdatedAt, &accessedAt, &n.SourceSession, &n.SourceAgent, &n.Version) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + if accessedAt.Valid { + n.AccessedAt = accessedAt.Time + } + if k.Valid { + n.Key = k.String + } + return n, nil +} + +func (s *Store) UpdateNode(ctx context.Context, n *Node) error { + start := time.Now() + err := retryOnBusy(func() error { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + return updateNodeQ(ctx, s.q(), n) + }, 5, 50*time.Millisecond) + attrs := attribute.NewSet(attribute.String("op", "update_node")) + telemetry.SQLiteQueryDuration.Record(ctx, time.Since(start).Seconds(), metric.WithAttributeSet(attrs)) + telemetry.SQLiteQueryCount.Add(ctx, 1, metric.WithAttributeSet(attrs)) + return err +} + +func updateNodeQ(ctx context.Context, q queryable, n *Node) error { + _, err := q.ExecContext(ctx, `UPDATE nodes SET type=?, content=?, content_hash=?, summary=?, scope=?, project=?, tier=?, tags=?, key=?, pinned=?, confidence=?, access_count=?, updated_at=?, accessed_at=?, source_session=?, source_agent=?, version=? WHERE id=?`, + n.Type, n.Content, n.ContentHash, n.Summary, n.Scope, n.Project, n.Tier, n.Tags, nullString(n.Key), n.Pinned, n.Confidence, n.AccessCount, + n.UpdatedAt, nullTime(n.AccessedAt), n.SourceSession, n.SourceAgent, n.Version, n.ID) + if err != nil { + return err + } + // Persist structured metadata (delete + insert). + return saveNodeMetadataQ(ctx, q, n.ID, n.Metadata) +} + +func (s *Store) UpdateNodeContent(ctx context.Context, id, newContent string) error { + return retryOnBusy(func() error { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + return updateNodeContentQ(ctx, s.q(), id, newContent) + }, 5, 50*time.Millisecond) +} + +func updateNodeContentQ(ctx context.Context, q queryable, id, newContent string) error { + _, err := q.ExecContext(ctx, `UPDATE nodes SET content=?, updated_at=CURRENT_TIMESTAMP WHERE id=?`, newContent, id) + return err +} + +func (s *Store) DeleteNode(ctx context.Context, id string) error { + return retryOnBusy(func() error { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer func() { _ = tx.Rollback() }() + if err := deleteNodeQ(ctx, tx, id); err != nil { + return err + } + return tx.Commit() + }, 5, 50*time.Millisecond) +} + +func deleteNodeQ(ctx context.Context, q queryable, id string) error { + if _, err := q.ExecContext(ctx, `DELETE FROM edges WHERE from_id=? OR to_id=?`, id, id); err != nil { + return err + } + _, err := q.ExecContext(ctx, `DELETE FROM nodes WHERE id=?`, id) + return err +} + +func (s *Store) ListNodes(ctx context.Context, f NodeFilter) ([]*Node, error) { + start := time.Now() + nodes, err := retryOnBusyVal(func() ([]*Node, error) { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + return listNodesQ(ctx, s.q(), f) + }, 5, 50*time.Millisecond) + attrs := attribute.NewSet(attribute.String("op", "list_nodes")) + telemetry.SQLiteQueryDuration.Record(ctx, time.Since(start).Seconds(), metric.WithAttributeSet(attrs)) + telemetry.SQLiteQueryCount.Add(ctx, 1, metric.WithAttributeSet(attrs)) + return nodes, err +} + +func listNodesQ(ctx context.Context, q queryable, f NodeFilter) ([]*Node, error) { + query := "SELECT id, type, content, content_hash, summary, scope, project, tier, tags, key, pinned, confidence, access_count, created_at, updated_at, accessed_at, source_session, source_agent, version FROM nodes WHERE 1=1" + var args []any + if f.Type != "" { + query += " AND type=?" + args = append(args, f.Type) + } + if f.Scope != "" { + query += " AND scope=?" + args = append(args, f.Scope) + } + if f.Project != "" { + query += " AND project=?" + args = append(args, f.Project) + } + if f.Tier > 0 { + query += " AND tier=?" + args = append(args, f.Tier) + } + if f.MinConfidence > 0 { + query += " AND confidence>=?" + args = append(args, f.MinConfidence) + } + if f.SourceSession != "" { + query += " AND source_session=?" + args = append(args, f.SourceSession) + } + if f.Pinned != nil { + query += " AND pinned=?" + args = append(args, *f.Pinned) + } + if f.Tag != "" { + // Delimiter-aware exact tag match: wrap both the stored CSV and the + // target in commas so "topic:foo" does not match "topic:foobar". + // '%' and '_' in the tag are escaped so they are treated literally. + esc := strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`).Replace(f.Tag) + query += ` AND (',' || tags || ',') LIKE ? ESCAPE '\'` + args = append(args, "%,"+esc+",%") + } + // Metadata key-value filters (AND semantics — all must match). + for k, v := range f.MetadataFilters { + query += ` AND EXISTS (SELECT 1 FROM node_metadata nm WHERE nm.node_id = nodes.id AND nm.key = ? AND nm.value = ?)` + args = append(args, k, v) + } + query += " LIMIT ?" + if f.Limit > 0 { + args = append(args, f.Limit) + } else { + args = append(args, 1000) + } + if f.Offset > 0 { + query += " OFFSET ?" + args = append(args, f.Offset) + } + rows, err := q.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + return scanNodes(rows) +} + +// escapeFTS5 escapes special FTS5 characters by wrapping each token in double +// quotes and escaping embedded quotes. This prevents FTS5 query injection via +// operators like *, -, AND, OR, NOT. +func escapeFTS5(query string) string { + words := strings.Fields(query) + for i, w := range words { + // Escape embedded quotes by doubling them, then wrap in quotes + w = strings.ReplaceAll(w, `"`, `""`) + words[i] = `"` + w + `"` + } + return strings.Join(words, " OR ") +} + +func (s *Store) SearchNodes(ctx context.Context, query string, limit int) ([]*Node, error) { + start := time.Now() + nodes, err := retryOnBusyVal(func() ([]*Node, error) { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + return searchNodesQ(ctx, s.q(), query, limit) + }, 5, 50*time.Millisecond) + attrs := attribute.NewSet(attribute.String("op", "search_nodes")) + telemetry.SQLiteQueryDuration.Record(ctx, time.Since(start).Seconds(), metric.WithAttributeSet(attrs)) + telemetry.SQLiteQueryCount.Add(ctx, 1, metric.WithAttributeSet(attrs)) + return nodes, err +} + +func searchNodesQ(ctx context.Context, q queryable, query string, limit int) ([]*Node, error) { + if limit <= 0 { + limit = 10 + } + ftsQuery := escapeFTS5(query) + rows, err := q.QueryContext(ctx, `SELECT n.id, n.type, n.content, n.content_hash, n.summary, n.scope, n.project, n.tier, n.tags, n.key, n.pinned, n.confidence, n.access_count, n.created_at, n.updated_at, n.accessed_at, n.source_session, n.source_agent, n.version + FROM nodes_fts f JOIN nodes n ON f.rowid = n.rowid WHERE nodes_fts MATCH ? ORDER BY rank LIMIT ?`, ftsQuery, limit) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + return scanNodes(rows) +} + +// --- Versions --- + +func (s *Store) SaveVersion(ctx context.Context, nodeID string, content, changedBy, reason string) error { + return retryOnBusy(func() error { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer func() { _ = tx.Rollback() }() + if err := saveVersionQ(ctx, tx, nodeID, content, changedBy, reason); err != nil { + return err + } + return tx.Commit() + }, 5, 50*time.Millisecond) +} + +func saveVersionQ(ctx context.Context, q queryable, nodeID string, content, changedBy, reason string) error { + var maxVer int + err := q.QueryRowContext(ctx, `SELECT COALESCE(MAX(version), 0) FROM node_versions WHERE node_id=?`, nodeID).Scan(&maxVer) + if err != nil { + return err + } + _, err = q.ExecContext(ctx, `INSERT INTO node_versions (node_id, version, content, changed_at, changed_by, reason) VALUES (?, ?, ?, ?, ?, ?)`, + nodeID, maxVer+1, content, time.Now().UTC(), changedBy, reason) + return err +} + +func (s *Store) GetVersions(ctx context.Context, nodeID string) ([]*NodeVersion, error) { + return retryOnBusyVal(func() ([]*NodeVersion, error) { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + return getVersionsQ(ctx, s.q(), nodeID) + }, 5, 50*time.Millisecond) +} + +func getVersionsQ(ctx context.Context, q queryable, nodeID string) ([]*NodeVersion, error) { + rows, err := q.QueryContext(ctx, `SELECT node_id, version, content, changed_at, changed_by, reason FROM node_versions WHERE node_id=? ORDER BY version`, nodeID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + var out []*NodeVersion + for rows.Next() { + v := &NodeVersion{} + if err := rows.Scan(&v.NodeID, &v.Version, &v.Content, &v.ChangedAt, &v.ChangedBy, &v.Reason); err != nil { + return nil, err + } + out = append(out, v) + } + return out, rows.Err() +} + +// --- AccessLog --- + +// LogAccess records a lightweight access event (INSERT only, no UPDATE churn). +func (s *Store) LogAccess(ctx context.Context, nodeID string) error { + return retryOnBusy(func() error { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + return logAccessQ(ctx, s.q(), nodeID) + }, 5, 50*time.Millisecond) +} + +func logAccessQ(ctx context.Context, q queryable, nodeID string) error { + _, err := q.ExecContext(ctx, `INSERT INTO access_log (node_id) VALUES (?)`, nodeID) + return err +} + +// FlushAccessLog aggregates access_log entries into nodes.access_count / accessed_at, +// then truncates the log. Runs atomically within a transaction. +func (s *Store) FlushAccessLog(ctx context.Context) (int, error) { + return retryOnBusyVal(func() (int, error) { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return 0, err + } + defer func() { _ = tx.Rollback() }() + n, err := flushAccessLogQ(ctx, tx) + if err != nil { + return 0, err + } + return n, tx.Commit() + }, 5, 50*time.Millisecond) +} + +func flushAccessLogQ(ctx context.Context, q queryable) (int, error) { + rows, err := q.QueryContext(ctx, ` + SELECT node_id, COUNT(*) as cnt, MAX(created_at) as last_at + FROM access_log + GROUP BY node_id`) + if err != nil { + return 0, err + } + defer func() { _ = rows.Close() }() + type agg struct { + nodeID string + count int + lastAt time.Time + } + var aggs []agg + for rows.Next() { + var a agg + var lastAtStr string + if err := rows.Scan(&a.nodeID, &a.count, &lastAtStr); err != nil { + return 0, err + } + if lastAtStr != "" { + t, _ := time.Parse(time.RFC3339Nano, lastAtStr) + if t.IsZero() { + t, _ = time.Parse("2006-01-02 15:04:05", lastAtStr) + } + a.lastAt = t + } + aggs = append(aggs, a) + } + if err := rows.Err(); err != nil { + return 0, err + } + if len(aggs) == 0 { + return 0, nil + } + for _, a := range aggs { + _, err := q.ExecContext(ctx, ` + UPDATE nodes + SET access_count = access_count + ?, + accessed_at = MAX(COALESCE(accessed_at, '1970-01-01'), ?) + WHERE id = ?`, + a.count, a.lastAt, a.nodeID) + if err != nil { + return 0, err + } + } + _, err = q.ExecContext(ctx, `DELETE FROM access_log`) + if err != nil { + return 0, err + } + return len(aggs), nil +} + +// --- Metadata --- + +func (s *Store) SaveNodeMetadata(ctx context.Context, nodeID string, meta map[string]string) error { + return retryOnBusy(func() error { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + return saveNodeMetadataQ(ctx, s.q(), nodeID, meta) + }, 5, 50*time.Millisecond) +} + +func (s *Store) LoadNodeMetadata(ctx context.Context, nodeIDs []string) (map[string]map[string]string, error) { + if len(nodeIDs) == 0 { + return nil, nil + } + ids, args := chunkedArgs(nodeIDs) + return retryOnBusyVal(func() (map[string]map[string]string, error) { + return loadNodeMetadataQ(ctx, s.q(), ids, args) + }, 2, 10*time.Millisecond) +} + +func loadNodeMetadataQ(ctx context.Context, q queryable, idsChunk []string, args []any) (map[string]map[string]string, error) { + query := "SELECT node_id, key, value FROM node_metadata WHERE node_id IN (?" + for i := 1; i < len(idsChunk); i++ { + query += ", ?" + } + query += ")" + rows, err := q.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + result := make(map[string]map[string]string) + for rows.Next() { + var nodeID, key, value string + if err := rows.Scan(&nodeID, &key, &value); err != nil { + return nil, err + } + if result[nodeID] == nil { + result[nodeID] = make(map[string]string) + } + result[nodeID][key] = value + } + return result, rows.Err() +} + +func chunkedArgs(ids []string) ([]string, []any) { + args := make([]any, len(ids)) + for i, id := range ids { + args[i] = id + } + return ids, args +} + +// FillNodeMetadata loads metadata for all given nodes in one batch query and +// populates each node's Metadata field in-place. No-op for an empty slice. +func (s *Store) FillNodeMetadata(ctx context.Context, nodes []*Node) error { + if len(nodes) == 0 { + return nil + } + ids := make([]string, len(nodes)) + for i, n := range nodes { + ids[i] = n.ID + } + meta, err := s.LoadNodeMetadata(ctx, ids) + if err != nil { + return err + } + for _, n := range nodes { + n.Metadata = meta[n.ID] + } + return nil +} + +// --- Signatures --- + +func (s *Store) SaveSignature(ctx context.Context, nodeID, signature string) error { + return retryOnBusy(func() error { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + return saveSignatureQ(ctx, s.q(), nodeID, signature) + }, 5, 50*time.Millisecond) +} + +func saveSignatureQ(ctx context.Context, q queryable, nodeID, signature string) error { + _, err := q.ExecContext(ctx, + `INSERT OR REPLACE INTO node_signatures (node_id, signature, signed_at) VALUES (?, ?, CURRENT_TIMESTAMP)`, + nodeID, signature) + return err +} + +func (s *Store) GetAllSignatures(ctx context.Context) (map[string]string, error) { + return retryOnBusyVal(func() (map[string]string, error) { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + return getAllSignaturesQ(ctx, s.db) + }, 5, 50*time.Millisecond) +} + +func getAllSignaturesQ(ctx context.Context, q queryable) (map[string]string, error) { + rows, err := q.QueryContext(ctx, `SELECT node_id, signature FROM node_signatures`) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + out := make(map[string]string) + for rows.Next() { + var nodeID, sig string + if err := rows.Scan(&nodeID, &sig); err != nil { + return nil, err + } + out[nodeID] = sig + } + return out, rows.Err() +} + +func (s *Store) GetNodesBatch(ctx context.Context, ids []string) ([]*Node, error) { + return retryOnBusyVal(func() ([]*Node, error) { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + return getNodesBatchQ(ctx, s.q(), ids) + }, 5, 50*time.Millisecond) +} + +func getNodesBatchQ(ctx context.Context, q queryable, ids []string) ([]*Node, error) { + if len(ids) == 0 { + return nil, nil + } + var all []*Node + for i := 0; i < len(ids); i += maxSQLVariables { + end := i + maxSQLVariables + if end > len(ids) { + end = len(ids) + } + chunk := ids[i:end] + placeholders := make([]string, len(chunk)) + args := make([]any, len(chunk)) + for j, id := range chunk { + placeholders[j] = "?" + args[j] = id + } + query := fmt.Sprintf(`SELECT id, type, content, content_hash, summary, scope, project, tier, tags, key, pinned, confidence, access_count, created_at, updated_at, accessed_at, source_session, source_agent, version FROM nodes WHERE id IN (%s)`, strings.Join(placeholders, ",")) + rows, err := q.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + nodes, err := scanNodes(rows) + _ = rows.Close() + if err != nil { + return nil, err + } + all = append(all, nodes...) + } + return all, nil +} diff --git a/storage/sqlite_tx.go b/storage/sqlite_tx.go new file mode 100644 index 0000000..fd0efbb --- /dev/null +++ b/storage/sqlite_tx.go @@ -0,0 +1,268 @@ +// This file is part of package storage. It holds the transactional store +// (WithTx and txStore) plus version rollback/diff, split verbatim out of +// sqlite.go for readability; behavior is unchanged. + +package storage + +import ( + "context" + "database/sql" + "fmt" + "time" +) + +// WithTx runs the given function inside a SQL transaction. +// If the function returns an error, the transaction is rolled back. +// +// The whole transaction is retried on SQLITE_BUSY. A transaction acquires the +// single SQLite writer more eagerly than a lone statement, so it can lose a +// race with the async ingestion goroutine; without retry, correctness-critical +// transactional writes (e.g. the atomic cycle-check+insert in AddEdge) would +// fail spuriously under concurrency. Re-running fn is safe because a busy error +// occurs before commit, so no partial state is visible. +func (s *Store) WithTx(ctx context.Context, fn func(Storage) error) error { + return retryOnBusy(func() error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer func() { _ = tx.Rollback() }() + + txStore := &txStore{tx: tx} + if err := fn(txStore); err != nil { + return err + } + return tx.Commit() + }, 5, 50*time.Millisecond) +} + +// txStore is a Storage implementation backed by a SQL transaction. +type txStore struct { + tx *sql.Tx +} + +// txStore is a thin wrapper that delegates all operations to shared *Q functions. +func (t *txStore) CreateNode(ctx context.Context, n *Node) error { return createNodeQ(ctx, t.tx, n) } + +func (t *txStore) GetNode(ctx context.Context, id string) (*Node, error) { + return getNodeQ(ctx, t.tx, id) +} + +func (t *txStore) GetNodeByKey(ctx context.Context, key, project string) (*Node, error) { + return getNodeByKeyQ(ctx, t.tx, key, project) +} + +func (t *txStore) GetNodesBatch(ctx context.Context, ids []string) ([]*Node, error) { + return getNodesBatchQ(ctx, t.tx, ids) +} + +func (t *txStore) UpdateNode(ctx context.Context, n *Node) error { return updateNodeQ(ctx, t.tx, n) } + +func (t *txStore) UpdateNodeContent(ctx context.Context, id, newContent string) error { + return updateNodeContentQ(ctx, t.tx, id, newContent) +} + +func (t *txStore) DeleteNode(ctx context.Context, id string) error { return deleteNodeQ(ctx, t.tx, id) } + +func (t *txStore) ListNodes(ctx context.Context, f NodeFilter) ([]*Node, error) { + return listNodesQ(ctx, t.tx, f) +} + +func (t *txStore) SearchNodes(ctx context.Context, query string, limit int) ([]*Node, error) { + return searchNodesQ(ctx, t.tx, query, limit) +} + +func (t *txStore) SearchNodeByHash(ctx context.Context, hash, scope, project string) (*Node, error) { + return searchNodeByHashQ(ctx, t.tx, hash, scope, project) +} + +func (t *txStore) GetNeighbors(ctx context.Context, nodeID string) ([]*Node, error) { + return getNeighborsQ(ctx, t.tx, nodeID) +} + +func (t *txStore) CreateEdge(ctx context.Context, e *Edge) error { return createEdgeQ(ctx, t.tx, e) } + +func (t *txStore) GetEdge(ctx context.Context, id string) (*Edge, error) { + return getEdgeQ(ctx, t.tx, id) +} + +func (t *txStore) InvalidateEdge(ctx context.Context, id string) error { + return invalidateEdgeQ(ctx, t.tx, id) +} + +func (t *txStore) DeleteEdge(ctx context.Context, id string) error { return deleteEdgeQ(ctx, t.tx, id) } + +func (t *txStore) GetEdgesFrom(ctx context.Context, nodeID string) ([]*Edge, error) { + return queryEdgesQ(ctx, t.tx, `SELECT id, from_id, to_id, type, acyclic, weight, metadata, valid_at, invalid_at, created_at FROM edges WHERE from_id=? AND invalid_at IS NULL`, nodeID) +} + +func (t *txStore) GetEdgesTo(ctx context.Context, nodeID string) ([]*Edge, error) { + return queryEdgesQ(ctx, t.tx, `SELECT id, from_id, to_id, type, acyclic, weight, metadata, valid_at, invalid_at, created_at FROM edges WHERE to_id=? AND invalid_at IS NULL`, nodeID) +} + +func (t *txStore) GetValidEdgesFrom(ctx context.Context, nodeID string, at time.Time) ([]*Edge, error) { + at = validityInstant(at) + return queryEdgesQ(ctx, t.tx, `SELECT id, from_id, to_id, type, acyclic, weight, metadata, valid_at, invalid_at, created_at FROM edges + WHERE from_id=? AND (valid_at IS NULL OR valid_at <= ?) AND (invalid_at IS NULL OR invalid_at > ?)`, nodeID, at, at) +} + +func (t *txStore) GetValidEdgesTo(ctx context.Context, nodeID string, at time.Time) ([]*Edge, error) { + at = validityInstant(at) + return queryEdgesQ(ctx, t.tx, `SELECT id, from_id, to_id, type, acyclic, weight, metadata, valid_at, invalid_at, created_at FROM edges + WHERE to_id=? AND (valid_at IS NULL OR valid_at <= ?) AND (invalid_at IS NULL OR invalid_at > ?)`, nodeID, at, at) +} + +func (t *txStore) GetEdgesBetween(ctx context.Context, nodeIDs []string) ([]*Edge, error) { + return getEdgesBetweenQ(ctx, t.tx, nodeIDs) +} + +func (t *txStore) GetAllEdgesFor(ctx context.Context, nodeIDs []string) ([]*Edge, error) { + return getAllEdgesForQ(ctx, t.tx, nodeIDs) +} + +func (t *txStore) CountEdges(ctx context.Context, nodeID string) (int, int, error) { + return countEdgesQ(ctx, t.tx, nodeID) +} + +func (t *txStore) CountAllEdges(ctx context.Context) (int, error) { return countAllEdgesQ(ctx, t.tx) } + +func (t *txStore) CountEdgesBatch(ctx context.Context, nodeIDs []string) (map[string][2]int, error) { + return countEdgesBatchQ(ctx, t.tx, nodeIDs) +} + +func (t *txStore) NodeStats(ctx context.Context) (map[string]int, int, error) { + return nodeStatsQ(ctx, t.tx) +} + +func (t *txStore) DoctorStats(ctx context.Context) (DoctorStatsResult, error) { + return doctorStatsQ(ctx, t.tx) +} + +func (t *txStore) LastUpdated(ctx context.Context) (time.Time, error) { return lastUpdatedQ(ctx, t.tx) } + +func (t *txStore) TopConnected(ctx context.Context, limit int) ([]string, error) { + return topConnectedQ(ctx, t.tx, limit) +} + +func (t *txStore) CheckCycle(ctx context.Context, fromID, toID string) (bool, error) { + return checkCycleQ(ctx, t.tx, fromID, toID) +} + +func (t *txStore) CreateSession(ctx context.Context, sess *Session) error { + return createSessionQ(ctx, t.tx, sess) +} + +func (t *txStore) EndSession(ctx context.Context, id string, summary string) error { + return endSessionQ(ctx, t.tx, id, summary) +} + +func (t *txStore) ListSessions(ctx context.Context, project string, limit int) ([]*Session, error) { + return listSessionsQ(ctx, t.tx, project, limit) +} + +func (t *txStore) SaveVersion(ctx context.Context, nodeID string, content, changedBy, reason string) error { + return saveVersionQ(ctx, t.tx, nodeID, content, changedBy, reason) +} + +func (t *txStore) GetVersions(ctx context.Context, nodeID string) ([]*NodeVersion, error) { + return getVersionsQ(ctx, t.tx, nodeID) +} + +func (t *txStore) SaveEmbedding(ctx context.Context, nodeID, model string, vector []float32) error { + return saveEmbeddingQ(ctx, t.tx, nodeID, model, vector) +} + +func (t *txStore) DeleteEmbedding(ctx context.Context, nodeID string) error { + return deleteEmbeddingQ(ctx, t.tx, nodeID) +} + +func (t *txStore) GetEmbedding(ctx context.Context, nodeID string) ([]float32, string, error) { + return getEmbeddingQ(ctx, t.tx, nodeID) +} + +func (t *txStore) AllEmbeddings(ctx context.Context, model string) (map[string][]float32, error) { + return allEmbeddingsQ(ctx, t.tx, model) +} + +func (t *txStore) GetEmbeddingsBatch(ctx context.Context, model string, offset, limit int) (map[string][]float32, error) { + return getEmbeddingsBatchQ(ctx, t.tx, model, offset, limit) +} + +func (t *txStore) AddFileWatch(ctx context.Context, filePath, nodeID, gitHash string) error { + return addFileWatchQ(ctx, t.tx, filePath, nodeID, gitHash) +} + +func (t *txStore) AddReplayEvent(ctx context.Context, sessionID, data string) error { + return addReplayEventQ(ctx, t.tx, sessionID, data) +} + +func (t *txStore) GetReplayEvents(ctx context.Context, sessionID string) ([]*ReplayEvent, error) { + return getReplayEventsQ(ctx, t.tx, sessionID) +} + +func (t *txStore) LogAccess(ctx context.Context, nodeID string) error { + return logAccessQ(ctx, t.tx, nodeID) +} + +func (t *txStore) SaveNodeMetadata(ctx context.Context, nodeID string, meta map[string]string) error { + return saveNodeMetadataQ(ctx, t.tx, nodeID, meta) +} + +func (t *txStore) LoadNodeMetadata(ctx context.Context, nodeIDs []string) (map[string]map[string]string, error) { + ids, args := chunkedArgs(nodeIDs) + return loadNodeMetadataQ(ctx, t.tx, ids, args) +} + +func (t *txStore) SaveSignature(ctx context.Context, nodeID, signature string) error { + return saveSignatureQ(ctx, t.tx, nodeID, signature) +} + +func (t *txStore) GetAllSignatures(ctx context.Context) (map[string]string, error) { + return getAllSignaturesQ(ctx, t.tx) +} + +func (t *txStore) FlushAccessLog(ctx context.Context) (int, error) { return flushAccessLogQ(ctx, t.tx) } +func (t *txStore) WithTx(ctx context.Context, fn func(Storage) error) error { return fn(t) } +func (t *txStore) Close() error { return nil } + +// RollbackToVersion restores a node's content to a specific version. +// All operations run inside a single transaction to prevent concurrent +// modifications from interleaving between the SELECT and UPDATE. +func (s *Store) RollbackToVersion(ctx context.Context, nodeID string, version int) error { + return s.WithTx(ctx, func(tx Storage) error { + // GetVersions is not on Storage interface, so use GetNode + version query via txStore. + // We need the raw version content. Use the versions listing through txStore. + var content string + // txStore implements Storage; cast to access version query + if tts, ok := tx.(*txStore); ok { + err := tts.tx.QueryRowContext(ctx, + `SELECT content FROM node_versions WHERE node_id=? AND version=?`, nodeID, version).Scan(&content) + if err != nil { + return fmt.Errorf("version %d not found for node %s: %w", version, nodeID, err) + } + } else { + return fmt.Errorf("unexpected storage type in transaction") + } + if err := tx.UpdateNodeContent(ctx, nodeID, content); err != nil { + return err + } + return tx.SaveVersion(ctx, nodeID, content, "system", fmt.Sprintf("rollback to v%d", version)) + }) +} + +// DiffVersions returns the content of two versions for comparison. +func (s *Store) DiffVersions(ctx context.Context, nodeID string, v1, v2 int) (content1, content2 string, err error) { + ctx, cancel := s.withTimeout(ctx) + defer cancel() + err = s.db.QueryRowContext(ctx, + `SELECT content FROM node_versions WHERE node_id=? AND version=?`, nodeID, v1).Scan(&content1) + if err != nil { + return "", "", fmt.Errorf("version %d not found: %w", v1, err) + } + err = s.db.QueryRowContext(ctx, + `SELECT content FROM node_versions WHERE node_id=? AND version=?`, nodeID, v2).Scan(&content2) + if err != nil { + return "", "", fmt.Errorf("version %d not found: %w", v2, err) + } + return content1, content2, nil +}