Skip to content

Commit 5c5811f

Browse files
committed
feat: Added a vector rest query interface supporting Cypher
- Added a vector rest query interface supporting Cypher. **vector:** ```vector curl -X POST http://localhost:8000/graph/query/vector \ -H "Content-Type: application/json" \ -d '{ "query": "MATCH (e:Person) RETURN e.name, e.embedding", "column": "e.embedding", "vector": [0.1, 0.2, 0.3], "metric": "cosine", "top_k": 5 }' ``` **query_text:** ```query_text curl -X POST http://localhost:8000/graph/query/vector \ -H "Content-Type: application/json" \ -d '{ "query": "MATCH (e:Person) RETURN e.name, e.embedding", "column": "e.embedding", "query_text": "machine learning researcher", "metric": "cosine", "top_k": 3 }' ```
1 parent 1c199ac commit 5c5811f

3 files changed

Lines changed: 203 additions & 3 deletions

File tree

python/python/knowledge_graph/__init__.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from .service import LanceKnowledgeGraph, create_default_service
2626
from .store import LanceGraphStore
2727
from .webservice import create_app
28+
from lance_graph import VectorSearch, DistanceMetric
2829

2930
TableMapping = Mapping[str, pa.Table]
3031

@@ -66,6 +67,23 @@ def run(
6667
)
6768
return query.execute(sources)
6869

70+
def run_with_vector_rerank(
71+
self,
72+
statement: str,
73+
vector_search: "VectorSearch",
74+
*,
75+
datasets: Optional[TableMapping] = None,
76+
) -> pa.Table:
77+
"""Execute a Cypher statement and rerank results by vector similarity."""
78+
79+
query = CypherQuery(statement).with_config(self.config)
80+
sources: Dict[str, pa.Table] = dict(self._tables)
81+
if datasets:
82+
sources.update(
83+
{name: _ensure_table(name, table) for name, table in datasets.items()}
84+
)
85+
return query.execute_with_vector_rerank(sources, vector_search)
86+
6987
def tables(self) -> Dict[str, pa.Table]:
7088
"""Return a shallow copy of the registered datasets."""
7189
return dict(self._tables)
@@ -129,4 +147,6 @@ def build(self) -> KnowledgeGraph:
129147
"preview_extraction",
130148
"HeuristicExtractor",
131149
"LLMExtractor",
150+
"VectorSearch",
151+
"DistanceMetric",
132152
]

python/python/knowledge_graph/component.py

Lines changed: 100 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22

33
from __future__ import annotations
44

5-
from typing import Any, Dict, List, Optional
5+
from typing import Any, Dict, List, Literal, Optional
66

77
import pyarrow as pa
88
import yaml
99
from fastapi import APIRouter, HTTPException
10-
from pydantic import BaseModel
10+
from pydantic import BaseModel, Field
1111

1212
from .config import KnowledgeGraphConfig
1313
from .service import LanceKnowledgeGraph
@@ -22,6 +22,41 @@ class QueryResponse(BaseModel):
2222
rows: List[Dict[str, Any]]
2323
row_count: int
2424

25+
class VectorQueryRequest(BaseModel):
26+
"""Request body for vector-reranked Cypher queries.
27+
28+
Supply either ``vector`` (raw floats) or ``query_text`` (auto-embedded via OpenAI).
29+
"""
30+
31+
query: str = Field(..., description="Cypher statement to execute.")
32+
column: str = Field(..., description="Name of the vector column to search.")
33+
34+
# Choose one: pass the vector directly or pass the text and let the server automatically embed it
35+
vector: Optional[List[float]] = Field(
36+
None, description="Query vector (float list). Mutually exclusive with query_text."
37+
)
38+
query_text: Optional[str] = Field(
39+
None, description="Text to embed as query vector. Requires OpenAI API key."
40+
)
41+
42+
metric: Literal["cosine", "l2", "dot"] = Field(
43+
"cosine", description="Distance metric: cosine | l2 | dot."
44+
)
45+
top_k: int = Field(10, ge=1, le=10000, description="Number of nearest neighbours.")
46+
include_distance: bool = Field(True, description="Include _distance column in results.")
47+
embedding_model: str = Field(
48+
"text-embedding-3-small",
49+
description="OpenAI embedding model (only used when query_text is provided).",
50+
)
51+
52+
53+
class VectorQueryResponse(BaseModel):
54+
rows: List[Dict[str, Any]]
55+
row_count: int
56+
column: str
57+
metric: str
58+
top_k: int
59+
2560

2661
class DatasetUpsertRequest(BaseModel):
2762
records: List[Dict[str, Any]]
@@ -98,6 +133,69 @@ async def get_schema() -> Dict[str, Any]:
98133
payload = yaml.safe_load(handle) or {}
99134
return {"path": str(schema_path), "schema": payload}
100135

136+
@self.router.post("/query/vector", response_model=VectorQueryResponse)
137+
async def execute_vector_query(request: VectorQueryRequest) -> VectorQueryResponse:
138+
"""Execute a Cypher query with vector similarity reranking.
139+
140+
Supply ``vector`` (raw floats) or ``query_text`` (auto-embedded).
141+
"""
142+
if request.vector is None and request.query_text is None:
143+
raise HTTPException(
144+
status_code=400,
145+
detail="Either 'vector' or 'query_text' must be provided.",
146+
)
147+
if request.vector is not None and request.query_text is not None:
148+
raise HTTPException(
149+
status_code=400,
150+
detail="Provide only one of 'vector' or 'query_text', not both.",
151+
)
152+
153+
service = self._get_service()
154+
155+
try:
156+
if request.query_text is not None:
157+
# Text: service internally calls EmbeddingGenerator
158+
result = service.query_by_text(
159+
request.query,
160+
request.query_text,
161+
request.column,
162+
top_k=request.top_k,
163+
metric=request.metric,
164+
include_distance=request.include_distance,
165+
embedding_model=request.embedding_model,
166+
)
167+
else:
168+
# Vector: Constructing VectorSearch directly
169+
from lance_graph import DistanceMetric, VectorSearch
170+
171+
_metric_map = {
172+
"cosine": DistanceMetric.Cosine,
173+
"l2": DistanceMetric.L2,
174+
"dot": DistanceMetric.Dot,
175+
}
176+
vs = (
177+
VectorSearch(request.column)
178+
.query_vector(request.vector)
179+
.metric(_metric_map[request.metric])
180+
.top_k(request.top_k)
181+
.include_distance(request.include_distance)
182+
)
183+
result = service.run_with_vector_rerank(request.query, vs)
184+
185+
except RuntimeError as exc:
186+
raise HTTPException(status_code=500, detail=str(exc)) from exc
187+
except ValueError as exc:
188+
raise HTTPException(status_code=400, detail=str(exc)) from exc
189+
190+
rows = result.to_pylist()
191+
return VectorQueryResponse(
192+
rows=rows,
193+
row_count=len(rows),
194+
column=request.column,
195+
metric=request.metric,
196+
top_k=request.top_k,
197+
)
198+
101199
def close(self) -> None:
102200
"""Release retained resources."""
103201
self._service = None

python/python/knowledge_graph/service.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from typing import TYPE_CHECKING, Iterable, Mapping, MutableMapping, Optional
66

7-
from lance_graph import CypherQuery, GraphConfig
7+
from lance_graph import CypherQuery, GraphConfig, VectorSearch, DistanceMetric
88

99
from .config import KnowledgeGraphConfig, build_default_graph_config
1010
from .store import LanceGraphStore
@@ -141,6 +141,88 @@ def query(
141141
"""Alias for :meth:`run` to match the semantic service naming."""
142142
return self.run(statement, datasets=datasets)
143143

144+
def run_with_vector_rerank(
145+
self,
146+
statement: str,
147+
vector_search: "VectorSearch",
148+
*,
149+
datasets: Optional[Mapping[str, "pa.Table"]] = None,
150+
) -> "pa.Table":
151+
"""Execute a Cypher statement and rerank results by vector similarity.
152+
153+
Parameters
154+
----------
155+
statement:
156+
Cypher query string.
157+
vector_search:
158+
A configured ``VectorSearch`` instance (column, vector, metric, top_k).
159+
datasets:
160+
Optional override tables injected on top of persisted datasets.
161+
"""
162+
query = CypherQuery(statement).with_config(self._config)
163+
164+
referenced_tables = set(query.node_labels()) | set(query.relationship_types())
165+
base_tables: MutableMapping[str, "pa.Table"] = dict(
166+
self._store.load_tables(referenced_tables)
167+
)
168+
if datasets:
169+
base_tables.update(datasets)
170+
return query.execute_with_vector_rerank(base_tables, vector_search)
171+
172+
def query_by_text(
173+
self,
174+
statement: str,
175+
query_text: str,
176+
column: str,
177+
*,
178+
top_k: int = 10,
179+
metric: str = "cosine",
180+
include_distance: bool = True,
181+
embedding_model: str = "text-embedding-3-small",
182+
datasets: Optional[Mapping[str, "pa.Table"]] = None,
183+
) -> "pa.Table":
184+
"""Convenience method: embed ``query_text`` then call run_with_vector_rerank.
185+
186+
Parameters
187+
----------
188+
statement:
189+
Cypher query string.
190+
query_text:
191+
Natural-language text to embed as the query vector.
192+
column:
193+
Name of the vector column in the dataset.
194+
top_k:
195+
Number of nearest neighbours to return.
196+
metric:
197+
Distance metric: "cosine", "l2", or "dot".
198+
include_distance:
199+
Whether to include the ``_distance`` column in results.
200+
embedding_model:
201+
OpenAI embedding model name.
202+
datasets:
203+
Optional override tables.
204+
"""
205+
from .embeddings import EmbeddingGenerator
206+
207+
_metric_map = {
208+
"cosine": DistanceMetric.Cosine,
209+
"l2": DistanceMetric.L2,
210+
"dot": DistanceMetric.Dot,
211+
}
212+
rust_metric = _metric_map.get(metric.lower(), DistanceMetric.Cosine)
213+
214+
vector = EmbeddingGenerator(model=embedding_model).embed_one(query_text)
215+
if vector is None:
216+
raise RuntimeError(f"Failed to generate embedding for text: {query_text!r}")
217+
218+
vs = (
219+
VectorSearch(column)
220+
.query_vector(vector)
221+
.metric(rust_metric)
222+
.top_k(top_k)
223+
.include_distance(include_distance)
224+
)
225+
return self.run_with_vector_rerank(statement, vs, datasets=datasets)
144226

145227
def create_default_service(
146228
config: Optional[KnowledgeGraphConfig] = None,

0 commit comments

Comments
 (0)