-
Notifications
You must be signed in to change notification settings - Fork 107
Expand file tree
/
Copy pathquery.py
More file actions
148 lines (126 loc) · 4.57 KB
/
query.py
File metadata and controls
148 lines (126 loc) · 4.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""Query implementation for codebase search."""
import heapq
import sqlite3
from typing import Any
import cocoindex as coco
from .config import config
from .schema import QueryResult
from .shared import SQLITE_DB, embedder, query_prompt_name
def _l2_to_score(distance: float) -> float:
    """Map an L2 distance onto a cosine-similarity score.

    For unit-length vectors ``||a - b||^2 == 2 - 2 * cos(a, b)``, so the
    conversion ``1 - d^2 / 2`` is exact in that case.

    Args:
        distance: L2 (Euclidean) distance between two embeddings.

    Returns:
        ``1.0 - distance**2 / 2.0``; 1.0 for identical vectors, lower for
        vectors that are farther apart.
    """
    squared = distance * distance
    return 1.0 - squared / 2.0
def _knn_query(
conn: sqlite3.Connection,
embedding_bytes: bytes,
k: int,
language: str | None = None,
) -> list[tuple[Any, ...]]:
"""Run a vec0 KNN query, optionally constrained to a language partition."""
if language is not None:
return conn.execute(
"""
SELECT file_path, language, content, start_line, end_line, distance
FROM code_chunks_vec
WHERE embedding MATCH ? AND k = ? AND language = ?
ORDER BY distance
""",
(embedding_bytes, k, language),
).fetchall()
return conn.execute(
"""
SELECT file_path, language, content, start_line, end_line, distance
FROM code_chunks_vec
WHERE embedding MATCH ? AND k = ?
ORDER BY distance
""",
(embedding_bytes, k),
).fetchall()
def _full_scan_query(
conn: sqlite3.Connection,
embedding_bytes: bytes,
limit: int,
offset: int,
languages: list[str] | None = None,
paths: list[str] | None = None,
) -> list[tuple[Any, ...]]:
"""Full scan with SQL-level distance computation and filtering."""
conditions: list[str] = []
params: list[Any] = [embedding_bytes]
if languages:
placeholders = ",".join("?" for _ in languages)
conditions.append(f"language IN ({placeholders})")
params.extend(languages)
if paths:
path_clauses = " OR ".join("file_path GLOB ?" for _ in paths)
conditions.append(f"({path_clauses})")
params.extend(paths)
where = f"WHERE {' AND '.join(conditions)}" if conditions else ""
params.extend([limit, offset])
return conn.execute(
f"""
SELECT file_path, language, content, start_line, end_line,
vec_distance_L2(embedding, ?) as distance
FROM code_chunks_vec
{where}
ORDER BY distance
LIMIT ? OFFSET ?
""",
params,
).fetchall()
async def query_codebase(
    query: str,
    limit: int = 10,
    offset: int = 0,
    languages: list[str] | None = None,
    paths: list[str] | None = None,
) -> list[QueryResult]:
    """
    Perform vector similarity search using vec0 KNN index.

    Uses sqlite-vec's vec0 virtual table for indexed nearest-neighbor search.
    Language filtering uses vec0 partition keys for exact index-level filtering.
    Path filtering triggers a full scan with distance computation.

    Args:
        query: Text to embed and search for.
        limit: Maximum number of results to return.
        offset: Number of best-ranked results to skip. Negative values are
            treated as 0, matching how SQLite handles a negative OFFSET.
        languages: Optional language names to restrict the search to.
            Repeated names are ignored.
        paths: Optional GLOB patterns matched against ``file_path``.

    Returns:
        ``QueryResult`` objects ordered by ascending distance, with ``score``
        derived from the L2 distance via ``_l2_to_score``.

    Raises:
        RuntimeError: If the index database file does not exist.
    """
    if not config.target_sqlite_db_path.exists():
        raise RuntimeError(
            f"Index database not found at {config.target_sqlite_db_path}. "
            "Please run a query with refresh_index=True first."
        )

    # The full-scan path hands OFFSET to SQLite, which treats a negative value
    # as zero. Clamp here so the KNN paths (which slice in Python, where a
    # negative index would count from the tail) behave the same way.
    offset = max(offset, 0)

    # Drop repeated language names while keeping their order. Otherwise the
    # same partition would be queried once per repetition and every hit would
    # appear multiple times in the merged result.
    unique_languages = list(dict.fromkeys(languages)) if languages else []

    coco_env = await coco.default_env()
    db = coco_env.get_context(SQLITE_DB)

    # Generate query embedding.
    query_embedding = await embedder.embed(query, query_prompt_name)
    # NOTE(review): presumably a numpy array; serialized as raw float32 bytes.
    embedding_bytes = query_embedding.astype("float32").tobytes()

    with db.value.readonly() as conn:
        if paths:
            # Path filter → full scan (vec0 can't filter on auxiliary columns).
            # LIMIT/OFFSET handled in SQL.
            rows = _full_scan_query(
                conn, embedding_bytes, limit, offset, unique_languages, paths
            )
        else:
            # KNN has no OFFSET: over-fetch, then drop the skipped prefix.
            fetch_k = limit + offset
            if len(unique_languages) <= 1:
                # Single language or no filter: one KNN query.
                lang = unique_languages[0] if unique_languages else None
                candidates = _knn_query(conn, embedding_bytes, fetch_k, lang)
            else:
                # Multiple languages: separate KNN per partition, merge by distance.
                candidates = heapq.nsmallest(
                    fetch_k,
                    (
                        row
                        for lang in unique_languages
                        for row in _knn_query(conn, embedding_bytes, fetch_k, lang)
                    ),
                    key=lambda r: r[5],  # distance column
                )
            rows = candidates[offset:]

    return [
        QueryResult(
            file_path=file_path,
            language=language,
            content=content,
            start_line=start_line,
            end_line=end_line,
            score=_l2_to_score(distance),
        )
        for file_path, language, content, start_line, end_line, distance in rows
    ]