|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | """Unit tests for ``aperag.indexing.graph_search_service`` — Wave 7 |
16 | | -task #5. |
| 16 | +task #5 + Wave 8 W8-1 task #12. |
17 | 17 |
|
18 | | -Pins the Wave 7 vector-recall contract: |
| 18 | +Pins the vector-recall contract: |
19 | 19 |
|
20 | 20 | * ``search_entities`` embeds the query, ANN-searches the |
21 | 21 | ``graph_entity`` indexer slice (filter + threshold pinned), then |
22 | 22 | fetches matching ``EntityWithLineage`` rows via per-name |
23 | 23 | ``get_entity`` (asyncio.gather). De-dups payload names so an aliased |
24 | 24 | entity returned twice doesn't double-fetch. |
25 | | -* ``search_relations`` derives relations as the 1-hop expansion of the |
26 | | - vector-recalled entities — vector store carries no per-relation |
27 | | - vectors in Wave 7. |
| 25 | +* ``search_relations`` (Wave 8 W8-1 upgrade) embeds the query, |
| 26 | + ANN-searches the ``graph_relation`` indexer slice, parses each hit's |
| 27 | + ``entity_name="src->tgt"`` + ``entity_type=relation_type`` payload |
| 28 | + per the task #3 writer, then reverse-looks-up |
| 29 | + ``RelationWithLineage`` via ``store.get_relation``. Hits whose |
| 30 | + payload doesn't parse cleanly are skipped silently. Replaces the |
| 31 | + Wave 7 conservative 1-hop expansion path now that task #3 is |
| 32 | + writing relation vectors. |
28 | 33 | * ``get_subgraph`` is a thin pass-through to |
29 | 34 | ``expand_neighbors_n_hops`` for MCP / retrieval callers. |
30 | 35 | * ``compose_context`` renders byte-for-byte the same LightRAG-style |
@@ -124,15 +129,22 @@ def __init__( |
124 | 129 | self, |
125 | 130 | entities: dict[str, EntityWithLineage] | None = None, |
126 | 131 | expansions: dict[tuple[str, ...], tuple[list[EntityWithLineage], list[RelationWithLineage]]] | None = None, |
| 132 | + relations: dict[tuple[str, str, str], RelationWithLineage] | None = None, |
127 | 133 | ) -> None: |
128 | 134 | self._entities = entities or {} |
129 | 135 | self._expansions = expansions or {} |
| 136 | + self._relations = relations or {} |
130 | 137 | self.get_entity_calls: list[str] = [] |
| 138 | + self.get_relation_calls: list[tuple[str, str, str]] = [] |
131 | 139 |
|
132 | 140 | async def get_entity(self, entity_name: str) -> EntityWithLineage | None: |
133 | 141 | self.get_entity_calls.append(entity_name) |
134 | 142 | return self._entities.get(entity_name) |
135 | 143 |
|
| 144 | + async def get_relation(self, source: str, target: str, type: str) -> RelationWithLineage | None: |
| 145 | + self.get_relation_calls.append((source, target, type)) |
| 146 | + return self._relations.get((source, target, type)) |
| 147 | + |
136 | 148 | async def expand_neighbors_n_hops( |
137 | 149 | self, |
138 | 150 | *, |
@@ -185,13 +197,14 @@ def _make_service( |
185 | 197 | *, |
186 | 198 | entities: dict[str, EntityWithLineage] | None = None, |
187 | 199 | expansions: dict[tuple[str, ...], tuple[list[EntityWithLineage], list[RelationWithLineage]]] | None = None, |
| 200 | + relations: dict[tuple[str, str, str], RelationWithLineage] | None = None, |
188 | 201 | hits: list[SearchHit] | None = None, |
189 | 202 | embedder: Any | None = None, |
190 | 203 | connector: Any | None = None, |
191 | 204 | top_k: int = 10, |
192 | 205 | score_threshold: float = 0.0, |
193 | 206 | ) -> tuple[GraphSearchService, _FakeStore, _FakeVectorConnector | Any, _FakeEmbedder | Any]: |
194 | | - store = _FakeStore(entities=entities, expansions=expansions) |
| 207 | + store = _FakeStore(entities=entities, expansions=expansions, relations=relations) |
195 | 208 | connector = connector if connector is not None else _FakeVectorConnector(hits=hits) |
196 | 209 | embedder = embedder if embedder is not None else _FakeEmbedder() |
197 | 210 | service = GraphSearchService( |
@@ -322,31 +335,147 @@ async def test_search_entities_swallows_vector_store_failure(): |
322 | 335 |
|
323 | 336 |
|
324 | 337 | # --------------------------------------------------------------------- |
325 | | -# search_relations |
| 338 | +# search_relations (Wave 8 W8-1 vector recall path) |
326 | 339 | # --------------------------------------------------------------------- |
327 | 340 |
|
328 | 341 |
|
329 | 342 | @pytest.mark.asyncio |
330 | | -async def test_search_relations_empty_when_no_entities_match(): |
331 | | - service, _, _, _ = _make_service(entities={}, hits=[]) |
332 | | - assert await service.search_relations("query") == [] |
| 343 | +async def test_search_relations_empty_query_returns_empty(): |
| 344 | + service, _, connector, embedder = _make_service(entities={}) |
| 345 | + assert await service.search_relations("") == [] |
| 346 | + assert await service.search_relations(" ") == [] |
| 347 | + assert connector.searches == [] |
| 348 | + assert embedder.calls == [] |
333 | 349 |
|
334 | 350 |
|
335 | 351 | @pytest.mark.asyncio |
336 | | -async def test_search_relations_returns_one_hop_expansion_of_entity_results(): |
337 | | - a = _entity("Alpha") |
338 | | - b = _entity("Beta") |
339 | | - rel = _relation("Alpha", "Beta") |
340 | | - service, _, _, _ = _make_service( |
341 | | - entities={"Alpha": a, "Beta": b}, |
| 352 | +async def test_search_relations_zero_topk_returns_empty(): |
| 353 | + service, _, connector, embedder = _make_service(entities={}) |
| 354 | + assert await service.search_relations("query", top_k=0) == [] |
| 355 | + assert connector.searches == [] |
| 356 | + assert embedder.calls == [] |
| 357 | + |
| 358 | + |
| 359 | +@pytest.mark.asyncio |
| 360 | +async def test_search_relations_uses_graph_relation_filter(): |
| 361 | + """Wave 8 W8-1: filter pinned to ``Eq("indexer", "graph_relation")`` |
| 362 | + so the ANN never bleeds into entity / chunk vectors sharing the |
| 363 | + physical collection.""" |
| 364 | + service, _, connector, _ = _make_service(entities={}, hits=[], score_threshold=0.42, top_k=5) |
| 365 | + await service.search_relations("query") |
| 366 | + assert len(connector.searches) == 1 |
| 367 | + request = connector.searches[0].request |
| 368 | + from aperag.indexing.graph_search_service import GRAPH_RELATION_INDEXER |
| 369 | + |
| 370 | + assert request.flt == Eq("indexer", GRAPH_RELATION_INDEXER) |
| 371 | + assert request.score_threshold == 0.42 |
| 372 | + assert request.top_k == 5 |
| 373 | + |
| 374 | + |
| 375 | +@pytest.mark.asyncio |
| 376 | +async def test_search_relations_parses_payload_and_resolves_via_get_relation(): |
| 377 | + """Hit payload carries ``entity_name="src->tgt"`` + |
| 378 | + ``entity_type=relation_type`` (per task #3 writer |
| 379 | + ``aperag/indexing/graph.py:1631``); we split, reverse-lookup via |
| 380 | + ``store.get_relation`` and return the full ``RelationWithLineage`` |
| 381 | + so :meth:`compose_context` keeps its byte-parity rendering.""" |
| 382 | + rel = _relation("Alpha", "Beta", relation_type="founded") |
| 383 | + service, store, _, _ = _make_service( |
| 384 | + relations={("Alpha", "Beta", "founded"): rel}, |
342 | 385 | hits=[ |
343 | | - SearchHit(id="1", score=0.9, payload={"entity_name": "Alpha"}), |
344 | | - SearchHit(id="2", score=0.8, payload={"entity_name": "Beta"}), |
| 386 | + SearchHit( |
| 387 | + id="1", |
| 388 | + score=0.9, |
| 389 | + payload={"entity_name": "Alpha->Beta", "entity_type": "founded"}, |
| 390 | + ), |
345 | 391 | ], |
346 | | - expansions={("Alpha", "Beta"): ([a, b], [rel])}, |
347 | 392 | ) |
348 | 393 | relations = await service.search_relations("query") |
349 | 394 | assert relations == [rel] |
| 395 | + assert store.get_relation_calls == [("Alpha", "Beta", "founded")] |
| 396 | + |
| 397 | + |
| 398 | +@pytest.mark.asyncio |
| 399 | +async def test_search_relations_preserves_hit_order_and_dedupes(): |
| 400 | + rel_ab = _relation("Alpha", "Beta", relation_type="founded") |
| 401 | + rel_bc = _relation("Beta", "Gamma", relation_type="acquired") |
| 402 | + service, store, _, _ = _make_service( |
| 403 | + relations={ |
| 404 | + ("Alpha", "Beta", "founded"): rel_ab, |
| 405 | + ("Beta", "Gamma", "acquired"): rel_bc, |
| 406 | + }, |
| 407 | + hits=[ |
| 408 | + SearchHit(id="1", score=0.9, payload={"entity_name": "Beta->Gamma", "entity_type": "acquired"}), |
| 409 | + SearchHit(id="2", score=0.8, payload={"entity_name": "Alpha->Beta", "entity_type": "founded"}), |
| 410 | + # Duplicate of the first hit (e.g. alias-redirect side-effect) |
| 411 | + # — must dedupe so we don't double-fetch the same edge. |
| 412 | + SearchHit(id="3", score=0.7, payload={"entity_name": "Beta->Gamma", "entity_type": "acquired"}), |
| 413 | + ], |
| 414 | + ) |
| 415 | + relations = await service.search_relations("query") |
| 416 | + assert [(r.source, r.target) for r in relations] == [ |
| 417 | + ("Beta", "Gamma"), |
| 418 | + ("Alpha", "Beta"), |
| 419 | + ] |
| 420 | + assert store.get_relation_calls == [("Beta", "Gamma", "acquired"), ("Alpha", "Beta", "founded")] |
| 421 | + |
| 422 | + |
| 423 | +@pytest.mark.asyncio |
| 424 | +async def test_search_relations_skips_payload_missing_arrow_or_type(): |
| 425 | + """A hit whose payload doesn't parse cleanly is dropped silently — |
| 426 | + better to skip than to surface an edge we can't reconstruct.""" |
| 427 | + rel_ab = _relation("Alpha", "Beta", relation_type="founded") |
| 428 | + service, store, _, _ = _make_service( |
| 429 | + relations={("Alpha", "Beta", "founded"): rel_ab}, |
| 430 | + hits=[ |
| 431 | + SearchHit(id="ghost1", score=1.0, payload={}), # empty payload: no entity_name / entity_type |
| 432 | + SearchHit( |
| 433 | + id="ghost2", score=0.99, payload={"entity_name": "AlphaBeta", "entity_type": "founded"} |
| 434 | + ), # missing arrow |
| 435 | + SearchHit(id="ghost3", score=0.98, payload={"entity_name": "Alpha->Beta"}), # missing entity_type |
| 436 | + SearchHit( |
| 437 | + id="ghost4", score=0.97, payload={"entity_name": "->Beta", "entity_type": "founded"} |
| 438 | + ), # empty source |
| 439 | + SearchHit( |
| 440 | + id="ghost5", score=0.96, payload={"entity_name": "Alpha->", "entity_type": "founded"} |
| 441 | + ), # empty target |
| 442 | + SearchHit(id="real", score=0.9, payload={"entity_name": "Alpha->Beta", "entity_type": "founded"}), |
| 443 | + ], |
| 444 | + ) |
| 445 | + relations = await service.search_relations("query") |
| 446 | + assert [(r.source, r.target) for r in relations] == [("Alpha", "Beta")] |
| 447 | + # Only the well-formed hit reached the store. |
| 448 | + assert store.get_relation_calls == [("Alpha", "Beta", "founded")] |
| 449 | + |
| 450 | + |
| 451 | +@pytest.mark.asyncio |
| 452 | +async def test_search_relations_drops_edge_gced_from_store(): |
| 453 | + """Vector hit for an edge that was deleted between sync and search |
| 454 | + → store.get_relation returns None → drop from result, no exception.""" |
| 455 | + service, _, _, _ = _make_service( |
| 456 | + relations={}, # store empty: every get_relation returns None |
| 457 | + hits=[ |
| 458 | + SearchHit(id="1", score=0.9, payload={"entity_name": "Alpha->Beta", "entity_type": "founded"}), |
| 459 | + ], |
| 460 | + ) |
| 461 | + assert await service.search_relations("query") == [] |
| 462 | + |
| 463 | + |
| 464 | +@pytest.mark.asyncio |
| 465 | +async def test_search_relations_swallows_embedder_failure(): |
| 466 | + service, store, connector, _ = _make_service(entities={}, embedder=_FailingEmbedder()) |
| 467 | + assert await service.search_relations("query") == [] |
| 468 | + assert connector.searches == [] |
| 469 | + assert store.get_relation_calls == [] |
| 470 | + |
| 471 | + |
| 472 | +@pytest.mark.asyncio |
| 473 | +async def test_search_relations_swallows_vector_store_failure(): |
| 474 | + service, store, connector, embedder = _make_service(entities={}, connector=_FailingVectorConnector()) |
| 475 | + assert await service.search_relations("query") == [] |
| 476 | + assert embedder.calls == ["query"] |
| 477 | + assert len(connector.searches) == 1 |
| 478 | + assert store.get_relation_calls == [] |
350 | 479 |
|
351 | 480 |
|
352 | 481 | # --------------------------------------------------------------------- |
|
0 commit comments