Skip to content

Commit 7909f13

Browse files
committed
feat: add multi-hop bidirectional context traversal
Implement depth-aware graph traversal for GetContext using a bidirectional recursive CTE in PostgreSQL. Previously, GetContext returned only 1-hop neighbors and ignored the depth parameter.
- Add GetBidirectional to the EdgeRepository interface
- Implement a bidirectional recursive CTE with cycle detection via a path array, DISTINCT deduplication, and a LIMIT 1000 safety cap
- Default to depth=1 for backward compatibility; cap depth at 5
- Add 5 unit tests covering default depth, multi-hop traversal, depth capping, a nil edge repository, and cycles
Closes #236
1 parent 70cf58f commit 7909f13

4 files changed

Lines changed: 273 additions & 14 deletions

File tree

core/entity/edge.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ type EdgeRepository interface {
3434
GetByTarget(ctx context.Context, ns *namespace.Namespace, urn string, filter EdgeFilter) ([]Edge, error)
3535
GetDownstream(ctx context.Context, ns *namespace.Namespace, urn string, depth int) ([]Edge, error)
3636
GetUpstream(ctx context.Context, ns *namespace.Namespace, urn string, depth int) ([]Edge, error)
37+
GetBidirectional(ctx context.Context, ns *namespace.Namespace, urn string, depth int) ([]Edge, error)
3738
Delete(ctx context.Context, ns *namespace.Namespace, sourceURN, targetURN, edgeType string) error
3839
DeleteByURN(ctx context.Context, ns *namespace.Namespace, urn string) error
3940
}

core/entity/service.go

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,9 @@ func (s *Service) Suggest(ctx context.Context, ns *namespace.Namespace, text str
124124
return nil, nil
125125
}
126126

127+
// maxContextDepth is the largest traversal depth GetContext will use; larger requested depths are clamped to it.
128+
const maxContextDepth = 5
129+
127130
// GetContext assembles a context subgraph around an entity.
128131
func (s *Service) GetContext(ctx context.Context, ns *namespace.Namespace, urn string, depth int) (*ContextGraph, error) {
129132
ent, err := s.repo.GetByURN(ctx, ns, urn)
@@ -134,23 +137,28 @@ func (s *Service) GetContext(ctx context.Context, ns *namespace.Namespace, urn s
134137
cg := &ContextGraph{Entity: ent}
135138

136139
if s.edges != nil {
137-
_ = depth // TODO: use depth for multi-hop traversal
138-
outgoing, _ := s.edges.GetBySource(ctx, ns, urn, EdgeFilter{Current: true})
139-
incoming, _ := s.edges.GetByTarget(ctx, ns, urn, EdgeFilter{Current: true})
140-
cg.Edges = append(outgoing, incoming...)
140+
if depth <= 0 {
141+
depth = 1
142+
}
143+
if depth > maxContextDepth {
144+
depth = maxContextDepth
145+
}
146+
147+
cg.Edges, err = s.edges.GetBidirectional(ctx, ns, urn, depth)
148+
if err != nil {
149+
return nil, fmt.Errorf("get context edges: %w", err)
150+
}
141151

142152
seen := map[string]bool{urn: true}
143153
for _, e := range cg.Edges {
144-
relURN := e.TargetURN
145-
if relURN == urn {
146-
relURN = e.SourceURN
147-
}
148-
if seen[relURN] {
149-
continue
150-
}
151-
seen[relURN] = true
152-
if rel, err := s.repo.GetByURN(ctx, ns, relURN); err == nil {
153-
cg.Related = append(cg.Related, rel)
154+
for _, candidate := range []string{e.SourceURN, e.TargetURN} {
155+
if seen[candidate] {
156+
continue
157+
}
158+
seen[candidate] = true
159+
if rel, err := s.repo.GetByURN(ctx, ns, candidate); err == nil {
160+
cg.Related = append(cg.Related, rel)
161+
}
154162
}
155163
}
156164
}

core/entity/service_test.go

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,221 @@ func TestService_GetTypes(t *testing.T) {
166166
}
167167
}
168168

169+
// mockEdgeRepo is a simple in-memory edge repository for testing.
170+
type mockEdgeRepo struct {
171+
edges []Edge
172+
}
173+
174+
func (m *mockEdgeRepo) Upsert(_ context.Context, _ *namespace.Namespace, e *Edge) error {
175+
m.edges = append(m.edges, *e)
176+
return nil
177+
}
178+
179+
func (m *mockEdgeRepo) GetBySource(_ context.Context, _ *namespace.Namespace, urn string, _ EdgeFilter) ([]Edge, error) {
180+
var result []Edge
181+
for _, e := range m.edges {
182+
if e.SourceURN == urn {
183+
result = append(result, e)
184+
}
185+
}
186+
return result, nil
187+
}
188+
189+
func (m *mockEdgeRepo) GetByTarget(_ context.Context, _ *namespace.Namespace, urn string, _ EdgeFilter) ([]Edge, error) {
190+
var result []Edge
191+
for _, e := range m.edges {
192+
if e.TargetURN == urn {
193+
result = append(result, e)
194+
}
195+
}
196+
return result, nil
197+
}
198+
199+
func (m *mockEdgeRepo) GetDownstream(_ context.Context, _ *namespace.Namespace, _ string, _ int) ([]Edge, error) {
200+
return nil, nil
201+
}
202+
203+
func (m *mockEdgeRepo) GetUpstream(_ context.Context, _ *namespace.Namespace, _ string, _ int) ([]Edge, error) {
204+
return nil, nil
205+
}
206+
207+
func (m *mockEdgeRepo) GetBidirectional(_ context.Context, _ *namespace.Namespace, urn string, depth int) ([]Edge, error) {
208+
// BFS traversal up to depth hops in both directions.
209+
type frontier struct {
210+
urn string
211+
level int
212+
}
213+
visited := map[string]bool{urn: true}
214+
queue := []frontier{{urn: urn, level: 0}}
215+
var result []Edge
216+
seen := map[string]bool{} // dedup edges by source+target+type
217+
218+
for len(queue) > 0 {
219+
cur := queue[0]
220+
queue = queue[1:]
221+
if cur.level >= depth {
222+
continue
223+
}
224+
for _, e := range m.edges {
225+
key := e.SourceURN + "|" + e.TargetURN + "|" + e.Type
226+
var neighbor string
227+
if e.SourceURN == cur.urn {
228+
neighbor = e.TargetURN
229+
} else if e.TargetURN == cur.urn {
230+
neighbor = e.SourceURN
231+
} else {
232+
continue
233+
}
234+
if !seen[key] {
235+
seen[key] = true
236+
result = append(result, e)
237+
}
238+
if !visited[neighbor] {
239+
visited[neighbor] = true
240+
queue = append(queue, frontier{urn: neighbor, level: cur.level + 1})
241+
}
242+
}
243+
}
244+
return result, nil
245+
}
246+
247+
func (m *mockEdgeRepo) Delete(_ context.Context, _ *namespace.Namespace, _, _, _ string) error {
248+
return nil
249+
}
250+
251+
func (m *mockEdgeRepo) DeleteByURN(_ context.Context, _ *namespace.Namespace, _ string) error {
252+
return nil
253+
}
254+
255+
func TestService_GetContext_DefaultDepth(t *testing.T) {
256+
repo := newMockRepo()
257+
edges := &mockEdgeRepo{}
258+
svc := NewService(repo, edges, nil)
259+
ctx := context.Background()
260+
ns := namespace.DefaultNamespace
261+
262+
// A -> B -> C (linear chain)
263+
_, _ = svc.Upsert(ctx, ns, &Entity{URN: "urn:a", Type: TypeTable, Name: "a"})
264+
_, _ = svc.Upsert(ctx, ns, &Entity{URN: "urn:b", Type: TypeTable, Name: "b"})
265+
_, _ = svc.Upsert(ctx, ns, &Entity{URN: "urn:c", Type: TypeTable, Name: "c"})
266+
edges.edges = []Edge{
267+
{SourceURN: "urn:a", TargetURN: "urn:b", Type: "lineage"},
268+
{SourceURN: "urn:b", TargetURN: "urn:c", Type: "lineage"},
269+
}
270+
271+
// depth=0 should default to 1 (only direct neighbors of B)
272+
cg, err := svc.GetContext(ctx, ns, "urn:b", 0)
273+
if err != nil {
274+
t.Fatalf("GetContext failed: %v", err)
275+
}
276+
if len(cg.Edges) != 2 {
277+
t.Errorf("expected 2 edges at depth 1, got %d", len(cg.Edges))
278+
}
279+
if len(cg.Related) != 2 {
280+
t.Errorf("expected 2 related entities, got %d", len(cg.Related))
281+
}
282+
}
283+
284+
func TestService_GetContext_MultiHop(t *testing.T) {
285+
repo := newMockRepo()
286+
edges := &mockEdgeRepo{}
287+
svc := NewService(repo, edges, nil)
288+
ctx := context.Background()
289+
ns := namespace.DefaultNamespace
290+
291+
// A -> B -> C -> D
292+
_, _ = svc.Upsert(ctx, ns, &Entity{URN: "urn:a", Type: TypeTable, Name: "a"})
293+
_, _ = svc.Upsert(ctx, ns, &Entity{URN: "urn:b", Type: TypeTable, Name: "b"})
294+
_, _ = svc.Upsert(ctx, ns, &Entity{URN: "urn:c", Type: TypeTable, Name: "c"})
295+
_, _ = svc.Upsert(ctx, ns, &Entity{URN: "urn:d", Type: TypeTable, Name: "d"})
296+
edges.edges = []Edge{
297+
{SourceURN: "urn:a", TargetURN: "urn:b", Type: "lineage"},
298+
{SourceURN: "urn:b", TargetURN: "urn:c", Type: "lineage"},
299+
{SourceURN: "urn:c", TargetURN: "urn:d", Type: "lineage"},
300+
}
301+
302+
// depth=2 from B should reach A, C, and D
303+
cg, err := svc.GetContext(ctx, ns, "urn:b", 2)
304+
if err != nil {
305+
t.Fatalf("GetContext failed: %v", err)
306+
}
307+
if len(cg.Edges) != 3 {
308+
t.Errorf("expected 3 edges at depth 2, got %d", len(cg.Edges))
309+
}
310+
if len(cg.Related) != 3 {
311+
t.Errorf("expected 3 related entities (a, c, d), got %d", len(cg.Related))
312+
}
313+
}
314+
315+
func TestService_GetContext_MaxDepthCap(t *testing.T) {
316+
repo := newMockRepo()
317+
edges := &mockEdgeRepo{}
318+
svc := NewService(repo, edges, nil)
319+
ctx := context.Background()
320+
ns := namespace.DefaultNamespace
321+
322+
_, _ = svc.Upsert(ctx, ns, &Entity{URN: "urn:a", Type: TypeTable, Name: "a"})
323+
324+
// depth=10 should be capped to maxContextDepth (5), not error
325+
cg, err := svc.GetContext(ctx, ns, "urn:a", 10)
326+
if err != nil {
327+
t.Fatalf("GetContext with large depth should not error: %v", err)
328+
}
329+
if cg.Entity.URN != "urn:a" {
330+
t.Errorf("expected entity urn:a, got %s", cg.Entity.URN)
331+
}
332+
}
333+
334+
func TestService_GetContext_NilEdges(t *testing.T) {
335+
repo := newMockRepo()
336+
svc := NewService(repo, nil, nil) // no edge repo
337+
ctx := context.Background()
338+
ns := namespace.DefaultNamespace
339+
340+
_, _ = svc.Upsert(ctx, ns, &Entity{URN: "urn:a", Type: TypeTable, Name: "a"})
341+
342+
cg, err := svc.GetContext(ctx, ns, "urn:a", 2)
343+
if err != nil {
344+
t.Fatalf("GetContext with nil edges should not error: %v", err)
345+
}
346+
if len(cg.Edges) != 0 {
347+
t.Errorf("expected 0 edges, got %d", len(cg.Edges))
348+
}
349+
if len(cg.Related) != 0 {
350+
t.Errorf("expected 0 related, got %d", len(cg.Related))
351+
}
352+
}
353+
354+
func TestService_GetContext_CycleHandling(t *testing.T) {
355+
repo := newMockRepo()
356+
edges := &mockEdgeRepo{}
357+
svc := NewService(repo, edges, nil)
358+
ctx := context.Background()
359+
ns := namespace.DefaultNamespace
360+
361+
// Cycle: A -> B -> C -> A
362+
_, _ = svc.Upsert(ctx, ns, &Entity{URN: "urn:a", Type: TypeTable, Name: "a"})
363+
_, _ = svc.Upsert(ctx, ns, &Entity{URN: "urn:b", Type: TypeTable, Name: "b"})
364+
_, _ = svc.Upsert(ctx, ns, &Entity{URN: "urn:c", Type: TypeTable, Name: "c"})
365+
edges.edges = []Edge{
366+
{SourceURN: "urn:a", TargetURN: "urn:b", Type: "lineage"},
367+
{SourceURN: "urn:b", TargetURN: "urn:c", Type: "lineage"},
368+
{SourceURN: "urn:c", TargetURN: "urn:a", Type: "lineage"},
369+
}
370+
371+
// depth=3 should not infinite loop
372+
cg, err := svc.GetContext(ctx, ns, "urn:a", 3)
373+
if err != nil {
374+
t.Fatalf("GetContext with cycle should not error: %v", err)
375+
}
376+
if len(cg.Edges) != 3 {
377+
t.Errorf("expected 3 edges in cycle, got %d", len(cg.Edges))
378+
}
379+
if len(cg.Related) != 2 {
380+
t.Errorf("expected 2 related entities (b, c), got %d", len(cg.Related))
381+
}
382+
}
383+
169384
func TestService_Search_NilRepos(t *testing.T) {
170385
svc := NewService(newMockRepo(), nil, nil)
171386
results, err := svc.Search(context.Background(), SearchConfig{Text: "test"})

store/edge_repository.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,41 @@ func (r *EdgeRepository) GetUpstream(ctx context.Context, ns *namespace.Namespac
7777
return r.traverse(ctx, ns, urn, depth, "upstream")
7878
}
7979

80+
func (r *EdgeRepository) GetBidirectional(ctx context.Context, ns *namespace.Namespace, urn string, depth int) ([]entity.Edge, error) {
81+
if depth <= 0 {
82+
depth = 1
83+
}
84+
85+
query := `
86+
WITH RECURSIVE graph(source_urn, target_urn, type, properties, depth, path, frontier) AS (
87+
SELECT source_urn, target_urn, type, properties, 1, ARRAY[source_urn], target_urn
88+
FROM edges
89+
WHERE namespace_id = $1 AND source_urn = $2 AND valid_to IS NULL
90+
UNION ALL
91+
SELECT source_urn, target_urn, type, properties, 1, ARRAY[target_urn], source_urn
92+
FROM edges
93+
WHERE namespace_id = $1 AND target_urn = $2 AND valid_to IS NULL
94+
UNION ALL
95+
SELECT e.source_urn, e.target_urn, e.type, e.properties, g.depth + 1, g.path || g.frontier, e.target_urn
96+
FROM edges e
97+
JOIN graph g ON e.source_urn = g.frontier
98+
WHERE e.target_urn <> ALL(g.path) AND e.valid_to IS NULL AND g.depth < $3
99+
UNION ALL
100+
SELECT e.source_urn, e.target_urn, e.type, e.properties, g.depth + 1, g.path || g.frontier, e.source_urn
101+
FROM edges e
102+
JOIN graph g ON e.target_urn = g.frontier
103+
WHERE e.source_urn <> ALL(g.path) AND e.valid_to IS NULL AND g.depth < $3
104+
)
105+
SELECT DISTINCT source_urn, target_urn, type, properties FROM graph
106+
LIMIT 1000`
107+
108+
var models []edgeModel
109+
if err := r.client.SelectContext(ctx, &models, query, ns.ID, urn, depth); err != nil {
110+
return nil, fmt.Errorf("traverse bidirectional: %w", err)
111+
}
112+
return toEdgeList(models), nil
113+
}
114+
80115
func (r *EdgeRepository) Delete(ctx context.Context, ns *namespace.Namespace, sourceURN, targetURN, edgeType string) error {
81116
_, err := r.client.ExecContext(ctx,
82117
`UPDATE edges SET valid_to = now() WHERE namespace_id = $1 AND source_urn = $2 AND target_urn = $3 AND type = $4 AND valid_to IS NULL`,

0 commit comments

Comments
 (0)