Skip to content

Commit ad5c157

Browse files
authored
fix: make sure search blocks on load-time indexing gracefully (#92)
1 parent fca7c5a commit ad5c157

3 files changed

Lines changed: 144 additions & 65 deletions

File tree

src/cocoindex_code/cli.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -180,15 +180,14 @@ def _search_with_wait_spinner(
180180
from rich.spinner import Spinner as _Spinner
181181

182182
err_console = _Console(stderr=True)
183-
waiting = False
184183

185-
# Use Live context so the spinner is cleaned up regardless of outcome
186-
with _Live(console=err_console, transient=True) as live:
184+
with _Live(_Spinner("dots", "Searching..."), console=err_console, transient=True) as live:
187185

188186
def _on_waiting() -> None:
189-
nonlocal waiting
190-
waiting = True
191-
live.update(_Spinner("dots", "Waiting for indexing to complete..."))
187+
live.update(
188+
_Spinner("dots", "Waiting for indexing to complete..."),
189+
refresh=True,
190+
)
192191

193192
resp = client.search(
194193
project_root=project_root,

src/cocoindex_code/daemon.py

Lines changed: 92 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import sys
1010
import threading
1111
import time
12-
from collections.abc import AsyncIterator
12+
from collections.abc import AsyncIterator, Callable
1313
from multiprocessing.connection import Connection, Listener
1414
from pathlib import Path
1515
from typing import Any
@@ -102,13 +102,11 @@ class ProjectRegistry:
102102

103103
_projects: dict[str, Project]
104104
_index_locks: dict[str, asyncio.Lock]
105-
_indexing: dict[str, bool]
106105
_embedder: Embedder
107106

108107
def __init__(self, embedder: Embedder) -> None:
109108
self._projects = {}
110109
self._index_locks = {}
111-
self._indexing = {}
112110
self._load_time_done: dict[str, asyncio.Event] = {}
113111
self._embedder = embedder
114112

@@ -127,80 +125,116 @@ async def get_project(self, project_root: str, *, suppress_auto_index: bool = Fa
127125
project = await Project.create(root, project_settings, self._embedder)
128126
self._projects[project_root] = project
129127
self._index_locks[project_root] = asyncio.Lock()
130-
self._indexing[project_root] = False
131-
132-
event = asyncio.Event()
133-
self._load_time_done[project_root] = event
134-
if suppress_auto_index:
135-
event.set()
136-
else:
137-
asyncio.create_task(self._load_time_index(project_root))
128+
self._load_time_done[project_root] = asyncio.Event()
129+
if not suppress_auto_index:
130+
asyncio.create_task(self._run_index(project_root))
138131
return self._projects[project_root]
139132

140-
def is_load_time_indexing(self, project_root: str) -> bool:
141-
"""Check if load-time indexing is in progress."""
133+
def should_wait_for_indexing(self, project_root: str) -> bool:
134+
"""Check if search should wait before querying.
135+
136+
Returns True if the index lock is held (indexing actively running)
137+
or the initial indexing hasn't completed yet (covers the window
138+
between task creation and lock acquisition).
139+
"""
140+
lock = self._index_locks.get(project_root)
141+
if lock is not None and lock.locked():
142+
return True
142143
event = self._load_time_done.get(project_root)
143144
return event is not None and not event.is_set()
144145

145-
async def wait_for_load_time_indexing(self, project_root: str) -> None:
146-
"""Wait for load-time indexing to complete. Returns immediately if not in progress."""
146+
async def wait_for_indexing_done(self, project_root: str) -> None:
147+
"""Wait until no indexing is in progress and initial indexing is complete."""
148+
# Wait for the initial indexing to complete (if pending)
147149
event = self._load_time_done.get(project_root)
148150
if event is not None:
149151
await event.wait()
152+
# Wait for any ongoing indexing to finish (lock released)
153+
lock = self._index_locks.get(project_root)
154+
if lock is not None and lock.locked():
155+
await lock.acquire()
156+
lock.release()
150157

151-
async def _load_time_index(self, project_root: str) -> None:
152-
"""Background load-time indexing, consuming the update_index stream."""
158+
async def _run_index(
159+
self,
160+
project_root: str,
161+
on_progress: Callable[[IndexingProgress], None] | None = None,
162+
) -> None:
163+
"""Run indexing for a project, acquiring and releasing the per-project lock.
164+
165+
This is the single place where indexing actually happens. It is used
166+
both as a fire-and-forget background task (load-time indexing) and as a
167+
spawned task inside ``update_index`` (client-driven indexing).
168+
169+
On completion (success or failure) it marks load-time as done
170+
(idempotent) and releases the lock.
171+
"""
172+
project = self._projects[project_root]
173+
lock = self._index_locks[project_root]
174+
175+
await lock.acquire()
153176
try:
154-
async for _ in self.update_index(project_root):
155-
pass
177+
await project.update_index(
178+
on_progress=on_progress,
179+
)
156180
except Exception:
157-
logger.exception("Load-time indexing failed for %s", project_root)
181+
logger.exception("Indexing failed for %s", project_root)
158182
finally:
159183
event = self._load_time_done.get(project_root)
160184
if event is not None:
161185
event.set()
186+
lock.release()
162187

163188
async def update_index(
164189
self, project_root: str, *, suppress_auto_index: bool = True
165190
) -> AsyncIterator[IndexStreamResponse]:
166-
"""Update index, yielding progress updates and a final IndexResponse."""
167-
project = await self.get_project(project_root, suppress_auto_index=suppress_auto_index)
191+
"""Update index, yielding progress updates and a final IndexResponse.
192+
193+
Streams ``IndexProgressUpdate`` messages while indexing is in progress,
194+
ending with a terminal ``IndexResponse``. If the lock is already held,
195+
yields ``IndexWaitingNotice`` first.
196+
197+
The actual indexing runs in a separate task (``_run_index``) so that
198+
client disconnects (``GeneratorExit``) do not abort the indexing.
199+
"""
200+
await self.get_project(project_root, suppress_auto_index=suppress_auto_index)
168201
lock = self._index_locks[project_root]
169202

170-
# If lock is already held, notify the client and block until released
203+
# If lock is already held, notify the client before blocking
171204
if lock.locked():
172205
yield IndexWaitingNotice()
173206

174-
async with lock:
175-
self._indexing[project_root] = True
176-
try:
177-
progress_queue: asyncio.Queue[IndexingProgress] = asyncio.Queue()
178-
179-
def on_progress(progress: IndexingProgress) -> None:
180-
progress_queue.put_nowait(progress)
181-
182-
update_task = asyncio.create_task(project.update_index(on_progress=on_progress))
183-
184-
# Drain the queue until the update completes
185-
while not update_task.done():
186-
try:
187-
progress = await asyncio.wait_for(progress_queue.get(), timeout=0.1)
188-
yield IndexProgressUpdate(progress=progress)
189-
except TimeoutError:
190-
continue
191-
192-
# Drain any remaining items
193-
while not progress_queue.empty():
194-
yield IndexProgressUpdate(progress=progress_queue.get_nowait())
195-
196-
# Propagate any exception from the update task
197-
update_task.result()
207+
progress_queue: asyncio.Queue[IndexingProgress] = asyncio.Queue()
208+
index_task = asyncio.create_task(
209+
self._run_index(
210+
project_root,
211+
on_progress=lambda p: progress_queue.put_nowait(p),
212+
)
213+
)
198214

199-
yield IndexResponse(success=True)
200-
except Exception as e:
201-
yield IndexResponse(success=False, message=str(e))
202-
finally:
203-
self._indexing[project_root] = False
215+
try:
216+
# Drain the queue until the task completes
217+
while not index_task.done():
218+
try:
219+
progress = await asyncio.wait_for(progress_queue.get(), timeout=0.1)
220+
yield IndexProgressUpdate(progress=progress)
221+
except TimeoutError:
222+
continue
223+
224+
# Drain any remaining items
225+
while not progress_queue.empty():
226+
yield IndexProgressUpdate(progress=progress_queue.get_nowait())
227+
228+
# Propagate any exception from the index task
229+
index_task.result()
230+
231+
yield IndexResponse(success=True)
232+
except GeneratorExit:
233+
# Client disconnected — _run_index continues in background and
234+
# handles cleanup (set the load-time event, release the lock) when done.
235+
return
236+
except Exception as e:
237+
yield IndexResponse(success=False, message=str(e))
204238

205239
async def search(
206240
self,
@@ -255,7 +289,8 @@ def get_status(self, project_root: str) -> ProjectStatusResponse:
255289
" GROUP BY language ORDER BY cnt DESC"
256290
).fetchall()
257291

258-
is_indexing = self._indexing.get(project_root, False)
292+
lock = self._index_locks.get(project_root)
293+
is_indexing = lock is not None and lock.locked()
259294
progress = project.indexing_stats if is_indexing else None
260295
return ProjectStatusResponse(
261296
indexing=is_indexing,
@@ -272,7 +307,6 @@ def remove_project(self, project_root: str) -> bool:
272307
was_loaded = project_root in self._projects
273308
project = self._projects.pop(project_root, None)
274309
self._index_locks.pop(project_root, None)
275-
self._indexing.pop(project_root, None)
276310
self._load_time_done.pop(project_root, None)
277311
if project is not None:
278312
project.close()
@@ -288,7 +322,6 @@ def close_all(self) -> None:
288322
project.close()
289323
self._projects.clear()
290324
self._index_locks.clear()
291-
self._indexing.clear()
292325
self._load_time_done.clear()
293326
gc.collect()
294327

@@ -297,7 +330,7 @@ def list_projects(self) -> list[DaemonProjectInfo]:
297330
return [
298331
DaemonProjectInfo(
299332
project_root=root,
300-
indexing=self._indexing.get(root, False),
333+
indexing=self._index_locks[root].locked(),
301334
)
302335
for root in self._projects
303336
]
@@ -384,9 +417,9 @@ def _recv() -> bytes:
384417
async def _search_with_wait(
385418
registry: ProjectRegistry, req: SearchRequest
386419
) -> AsyncIterator[SearchStreamResponse]:
387-
"""Stream search response, waiting for load-time indexing first."""
420+
"""Stream search response, waiting for ongoing indexing first."""
388421
yield IndexWaitingNotice()
389-
await registry.wait_for_load_time_indexing(req.project_root)
422+
await registry.wait_for_indexing_done(req.project_root)
390423
try:
391424
results = await registry.search(
392425
project_root=req.project_root,
@@ -427,7 +460,7 @@ async def _dispatch(
427460
await registry.get_project(req.project_root)
428461

429462
# If any indexing is in progress (load-time or client-driven), return a streaming response
430-
if registry.is_load_time_indexing(req.project_root):
463+
if registry.should_wait_for_indexing(req.project_root):
431464
return _search_with_wait(registry, req)
432465

433466
results = await registry.search(

tests/test_daemon.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,53 @@ def test_daemon_remove_project_not_loaded(daemon_sock: str) -> None:
238238
conn.close()
239239

240240

241+
def test_daemon_search_waits_during_explicit_index(daemon_sock: str) -> None:
242+
"""When IndexRequest is in progress, a concurrent SearchRequest should receive
243+
IndexWaitingNotice (Path B: index first, then search)."""
244+
# Use enough files to ensure indexing takes long enough for the search to
245+
# arrive while it's still in progress.
246+
project = Path(tempfile.mkdtemp(prefix="ccc_idx_then_search_"))
247+
save_project_settings(project, default_project_settings())
248+
for i in range(20):
249+
(project / f"module_{i}.py").write_text(
250+
f'"""Module {i}."""\n\ndef func_{i}(x: int) -> int:\n'
251+
f' """Compute something for module {i}."""\n'
252+
f" return x * {i} + {i}\n"
253+
)
254+
255+
# Connection 1: start indexing (don't wait for completion)
256+
conn1, _ = _connect_and_handshake(daemon_sock)
257+
conn1.send_bytes(encode_request(IndexRequest(project_root=str(project))))
258+
259+
# Send the search request immediately — the daemon processes requests
260+
# concurrently across connections, and _run_index needs to acquire the
261+
# lock before indexing starts, so a prompt SearchRequest will arrive
262+
# while the event is still unset.
263+
conn2, _ = _connect_and_handshake(daemon_sock)
264+
conn2.send_bytes(encode_request(SearchRequest(project_root=str(project), query="compute")))
265+
266+
got_waiting = False
267+
final_resp: SearchResponse | None = None
268+
while True:
269+
resp = decode_response(conn2.recv_bytes())
270+
if isinstance(resp, IndexWaitingNotice):
271+
got_waiting = True
272+
continue
273+
if isinstance(resp, SearchResponse):
274+
final_resp = resp
275+
break
276+
raise AssertionError(f"Unexpected response on search conn: {type(resp).__name__}")
277+
278+
assert got_waiting, "Expected IndexWaitingNotice before SearchResponse"
279+
assert final_resp is not None
280+
assert final_resp.success is True
281+
282+
# Drain the index stream on connection 1
283+
_recv_index_response(conn1)
284+
conn1.close()
285+
conn2.close()
286+
287+
241288
def test_daemon_search_waits_for_load_time_indexing(daemon_sock: str) -> None:
242289
"""Search on a fresh project should wait for load-time indexing, sending IndexWaitingNotice."""
243290
# Create a new project that the daemon hasn't seen — its first load will

0 commit comments

Comments (0)