Skip to content

Commit ac8a8c7

Browse files
committed
feat: Added a vector REST query interface supporting Cypher
- Added a vector REST query interface supporting Cypher.

**vector:**
```bash
curl -X POST http://localhost:8000/graph/query/vector \
  -H "Content-Type: application/json" \
  -d '{ "query": "MATCH (e:Person) RETURN e.name, e.embedding", "column": "e.embedding", "vector": [0.1, 0.2, 0.3], "metric": "cosine", "top_k": 5 }'
```

**query_text:**
```bash
curl -X POST http://localhost:8000/graph/query/vector \
  -H "Content-Type: application/json" \
  -d '{ "query": "MATCH (e:Person) RETURN e.name, e.embedding", "column": "e.embedding", "query_text": "machine learning researcher", "metric": "cosine", "top_k": 3 }'
```
1 parent 1c199ac commit ac8a8c7

3 files changed

Lines changed: 212 additions & 3 deletions

File tree

python/python/knowledge_graph/__init__.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
except ImportError: # pragma: no cover - builder is available in normal installs.
1414
GraphConfigBuilder = object # type: ignore[assignment]
1515

16+
from lance_graph import DistanceMetric, VectorSearch
17+
1618
from .component import KnowledgeGraphComponent
1719
from .config import KnowledgeGraphConfig, build_graph_config_from_mapping
1820
from .extraction import (
@@ -66,6 +68,23 @@ def run(
6668
)
6769
return query.execute(sources)
6870

71+
def run_with_vector_rerank(
    self,
    statement: str,
    vector_search: "VectorSearch",
    *,
    datasets: Optional[TableMapping] = None,
) -> pa.Table:
    """Run ``statement`` and rerank its rows with ``vector_search``.

    The query executes against a copy of the registered tables. Entries in
    ``datasets`` are passed through ``_ensure_table`` and override registered
    tables of the same name for this call only; ``self._tables`` is not
    modified.
    """
    cypher = CypherQuery(statement).with_config(self.config)

    merged: Dict[str, pa.Table] = dict(self._tables)
    # A missing or empty ``datasets`` mapping contributes nothing.
    for name, table in (datasets or {}).items():
        merged[name] = _ensure_table(name, table)

    return cypher.execute_with_vector_rerank(merged, vector_search)
87+
6988
def tables(self) -> Dict[str, pa.Table]:
    """Return a shallow copy of the registered datasets.

    The returned dict is a new object, so adding or removing keys on it
    does not affect ``self._tables``. The table values themselves are
    shared with the registry, not copied.
    """
    return dict(self._tables)
@@ -129,4 +148,6 @@ def build(self) -> KnowledgeGraph:
129148
"preview_extraction",
130149
"HeuristicExtractor",
131150
"LLMExtractor",
151+
"VectorSearch",
152+
"DistanceMetric",
132153
]

python/python/knowledge_graph/component.py

Lines changed: 107 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22

33
from __future__ import annotations
44

5-
from typing import Any, Dict, List, Optional
5+
from typing import Any, Dict, List, Literal, Optional
66

77
import pyarrow as pa
88
import yaml
99
from fastapi import APIRouter, HTTPException
10-
from pydantic import BaseModel
10+
from pydantic import BaseModel, Field
1111

1212
from .config import KnowledgeGraphConfig
1313
from .service import LanceKnowledgeGraph
@@ -23,6 +23,46 @@ class QueryResponse(BaseModel):
2323
row_count: int
2424

2525

26+
class VectorQueryRequest(BaseModel):
    """Request body for vector-reranked Cypher queries.

    Supply either ``vector`` (raw floats) or ``query_text`` (auto-embedded via OpenAI).
    """

    # --- What to run and which column to rank on -------------------------
    query: str = Field(default=..., description="Cypher statement to execute.")
    column: str = Field(
        default=..., description="Name of the vector column to search."
    )

    # --- Query vector source ---------------------------------------------
    # Exactly one of the two fields below is expected: either the caller
    # sends the vector itself, or sends text for the server to embed.
    vector: Optional[List[float]] = Field(
        default=None,
        description="Query vector (float list). Mutually exclusive with query_text.",
    )
    query_text: Optional[str] = Field(
        default=None,
        description="Text to embed as query vector. Requires OpenAI API key.",
    )

    # --- Search tuning ---------------------------------------------------
    metric: Literal["cosine", "l2", "dot"] = Field(
        default="cosine", description="Distance metric: cosine | l2 | dot."
    )
    top_k: int = Field(
        default=10,
        ge=1,
        le=10000,
        description="Number of nearest neighbours.",
    )
    include_distance: bool = Field(
        default=True, description="Include _distance column in results."
    )
    embedding_model: str = Field(
        default="text-embedding-3-small",
        description="OpenAI embedding model (only used when query_text is provided).",
    )
57+
58+
class VectorQueryResponse(BaseModel):
    # Response body for the vector query endpoint. Documented with comments
    # rather than a docstring so the generated schema stays unchanged.

    # Result rows, one dict per row.
    rows: List[Dict[str, Any]]
    # Number of entries in ``rows``.
    row_count: int
    # The next three fields echo the corresponding request values.
    column: str
    metric: str
    top_k: int
64+
65+
2666
class DatasetUpsertRequest(BaseModel):
2767
records: List[Dict[str, Any]]
2868
merge: bool = True
@@ -98,6 +138,71 @@ async def get_schema() -> Dict[str, Any]:
98138
payload = yaml.safe_load(handle) or {}
99139
return {"path": str(schema_path), "schema": payload}
100140

141+
@self.router.post("/query/vector", response_model=VectorQueryResponse)
async def execute_vector_query(
    request: VectorQueryRequest,
) -> VectorQueryResponse:
    """Execute a Cypher query with vector similarity reranking.

    Supply ``vector`` (raw floats) or ``query_text`` (auto-embedded).

    Raises
    ------
    HTTPException
        400 when neither or both of ``vector`` / ``query_text`` are given,
        when the supplied one is empty, or when the service raises
        ``ValueError``; 500 when the service raises ``RuntimeError``.
    """
    if request.vector is None and request.query_text is None:
        raise HTTPException(
            status_code=400,
            detail="Either 'vector' or 'query_text' must be provided.",
        )
    if request.vector is not None and request.query_text is not None:
        raise HTTPException(
            status_code=400,
            detail="Provide only one of 'vector' or 'query_text', not both.",
        )
    # Reject degenerate inputs up front: an empty vector or blank text is
    # a client error and should not surface as a downstream failure.
    if request.vector is not None and not request.vector:
        raise HTTPException(
            status_code=400,
            detail="'vector' must contain at least one value.",
        )
    if request.query_text is not None and not request.query_text.strip():
        raise HTTPException(
            status_code=400,
            detail="'query_text' must not be empty.",
        )

    service = self._get_service()

    try:
        if request.query_text is not None:
            # Text: the service embeds the text, then runs the rerank.
            result = service.query_by_text(
                request.query,
                request.query_text,
                request.column,
                top_k=request.top_k,
                metric=request.metric,
                include_distance=request.include_distance,
                embedding_model=request.embedding_model,
            )
        else:
            # Vector: build the VectorSearch directly from the request.
            from lance_graph import DistanceMetric, VectorSearch

            _metric_map = {
                "cosine": DistanceMetric.Cosine,
                "l2": DistanceMetric.L2,
                "dot": DistanceMetric.Dot,
            }
            vs = (
                VectorSearch(request.column)
                .query_vector(request.vector)
                .metric(_metric_map[request.metric])
                .top_k(request.top_k)
                .include_distance(request.include_distance)
            )
            result = service.run_with_vector_rerank(request.query, vs)

    except RuntimeError as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    rows = result.to_pylist()
    return VectorQueryResponse(
        rows=rows,
        row_count=len(rows),
        column=request.column,
        metric=request.metric,
        top_k=request.top_k,
    )
205+
101206
def close(self) -> None:
102207
"""Release retained resources."""
103208
self._service = None

python/python/knowledge_graph/service.py

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from typing import TYPE_CHECKING, Iterable, Mapping, MutableMapping, Optional
66

7-
from lance_graph import CypherQuery, GraphConfig
7+
from lance_graph import CypherQuery, DistanceMetric, GraphConfig, VectorSearch
88

99
from .config import KnowledgeGraphConfig, build_default_graph_config
1010
from .store import LanceGraphStore
@@ -141,6 +141,89 @@ def query(
141141
"""Alias for :meth:`run` to match the semantic service naming."""
142142
return self.run(statement, datasets=datasets)
143143

144+
def run_with_vector_rerank(
145+
self,
146+
statement: str,
147+
vector_search: "VectorSearch",
148+
*,
149+
datasets: Optional[Mapping[str, "pa.Table"]] = None,
150+
) -> "pa.Table":
151+
"""Execute a Cypher statement and rerank results by vector similarity.
152+
153+
Parameters
154+
----------
155+
statement:
156+
Cypher query string.
157+
vector_search:
158+
A configured ``VectorSearch`` instance (column, vector, metric, top_k).
159+
datasets:
160+
Optional override tables injected on top of persisted datasets.
161+
"""
162+
query = CypherQuery(statement).with_config(self._config)
163+
164+
referenced_tables = set(query.node_labels()) | set(query.relationship_types())
165+
base_tables: MutableMapping[str, "pa.Table"] = dict(
166+
self._store.load_tables(referenced_tables)
167+
)
168+
if datasets:
169+
base_tables.update(datasets)
170+
return query.execute_with_vector_rerank(base_tables, vector_search)
171+
172+
def query_by_text(
    self,
    statement: str,
    query_text: str,
    column: str,
    *,
    top_k: int = 10,
    metric: str = "cosine",
    include_distance: bool = True,
    embedding_model: str = "text-embedding-3-small",
    datasets: Optional[Mapping[str, "pa.Table"]] = None,
) -> "pa.Table":
    """Convenience method: embed ``query_text`` then call run_with_vector_rerank.

    Parameters
    ----------
    statement:
        Cypher query string.
    query_text:
        Natural-language text to embed as the query vector.
    column:
        Name of the vector column in the dataset.
    top_k:
        Number of nearest neighbours to return.
    metric:
        Distance metric: "cosine", "l2", or "dot" (case-insensitive).
    include_distance:
        Whether to include the ``_distance`` column in results.
    embedding_model:
        OpenAI embedding model name.
    datasets:
        Optional override tables.

    Raises
    ------
    ValueError
        If ``metric`` is not one of the supported names.
    RuntimeError
        If the embedding generator returns ``None`` for ``query_text``.
    """
    # Validate the metric before embedding, so a typo fails fast instead of
    # silently ranking with cosine.
    metric_attrs = {"cosine": "Cosine", "l2": "L2", "dot": "Dot"}
    metric_key = metric.lower()
    if metric_key not in metric_attrs:
        raise ValueError(
            f"Unsupported metric {metric!r}; expected one of: "
            f"{', '.join(sorted(metric_attrs))}."
        )

    from .embeddings import EmbeddingGenerator

    rust_metric = getattr(DistanceMetric, metric_attrs[metric_key])

    vector = EmbeddingGenerator(model=embedding_model).embed_one(query_text)
    if vector is None:
        raise RuntimeError(f"Failed to generate embedding for text: {query_text!r}")

    vs = (
        VectorSearch(column)
        .query_vector(vector)
        .metric(rust_metric)
        .top_k(top_k)
        .include_distance(include_distance)
    )
    return self.run_with_vector_rerank(statement, vs, datasets=datasets)
226+
144227

145228
def create_default_service(
146229
config: Optional[KnowledgeGraphConfig] = None,

0 commit comments

Comments
 (0)