diff --git a/src/cocoindex_code/cli.py b/src/cocoindex_code/cli.py index 71ebab9..24e1012 100644 --- a/src/cocoindex_code/cli.py +++ b/src/cocoindex_code/cli.py @@ -3,11 +3,12 @@ from __future__ import annotations import functools +import json as _json import os import sys from collections.abc import Callable from pathlib import Path -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Protocol, TextIO, TypeVar import typer as _typer @@ -102,6 +103,20 @@ def require_project_root() -> Path: _F = TypeVar("_F", bound=Callable[..., object]) +class _SearchCallable(Protocol): + def __call__( + self, + project_root: str, + query: str, + languages: list[str] | None = None, + paths: list[str] | None = None, + repo_keys: list[str] | None = None, + limit: int = 5, + offset: int = 0, + on_waiting: Callable[[], None] | None = None, + ) -> SearchResponse: ... + + def _catch_daemon_start_error(func: _F) -> _F: """Decorator that catches ``DaemonStartError`` and exits with a clean message. @@ -181,6 +196,176 @@ def print_search_results(response: SearchResponse) -> None: _typer.echo(r.content) +def search_response_json_payload(response: SearchResponse) -> dict[str, object]: + """Build the machine-readable search response payload.""" + return { + "success": response.success, + "results": [ + { + "file_path": r.file_path, + "repo_key": r.repo_key, + "language": r.language, + "content": r.content, + "start_line": r.start_line, + "end_line": r.end_line, + "score": r.score, + } + for r in response.results + ], + "total_returned": response.total_returned, + "offset": response.offset, + "message": response.message, + } + + +def print_search_results_json(response: SearchResponse) -> None: + """Print search results as machine-readable JSON.""" + payload = search_response_json_payload(response) + _typer.echo(_json.dumps(payload, indent=2)) + + +def _jsonrpc_id(value: object) -> str | int | None: + if value is None or isinstance(value, str): + return value + if isinstance(value, int) and not isinstance(value, bool): + return value + raise ValueError("JSON-RPC id must be a string, integer, or null") + + +def _jsonrpc_success(request_id: str | int | None, result: object) -> dict[str, object]: + return { + "jsonrpc": "2.0", + "id": request_id, + "result": result, + } + + +def _jsonrpc_error( + request_id: str | int | None, + code: int, + message: str, +) -> dict[str, object]: + return { + "jsonrpc": "2.0", + "id": request_id, + "error": { + "code": code, + "message": message, + }, + } + + +def _required_str(params: dict[str, object], name: str) -> str: + value = params.get(name) + if not isinstance(value, str) or not value: + raise ValueError(f"params.{name} must be a non-empty string") + return value + + +def _optional_str_list(params: dict[str, object], name: str) -> list[str] | None: + value = params.get(name) + if value is None: + return None + if not isinstance(value, list): + raise ValueError(f"params.{name} must be a list of strings") + result: list[str] = [] + for item in value: + if not isinstance(item, str): + raise ValueError(f"params.{name} must be a list of strings") + result.append(item) + return result + + +def _positive_int_param(params: dict[str, object], name: str, default: int) -> int: + value = params.get(name) + if value is None: + return default + if not isinstance(value, int) or isinstance(value, bool) or value <= 0: + raise ValueError(f"params.{name} must be a positive integer") + return value + + +def _non_negative_int_param(params: dict[str, object], name: str, default: int) -> int: + value = params.get(name) + if value is None: + return default + if not isinstance(value, int) or isinstance(value, bool) or value < 0: + raise ValueError(f"params.{name} must be a non-negative integer") + return value + + +def handle_bridge_jsonrpc_request( + request: object, + search_func: _SearchCallable, +) -> tuple[dict[str, object], bool]: + """Handle one JSON-RPC bridge request.""" + request_id: str | int | None = None + try: + if not isinstance(request, dict): + return _jsonrpc_error(None, -32600, "Invalid Request"), False + raw_id = request.get("id") + request_id = _jsonrpc_id(raw_id) + if request.get("jsonrpc") != "2.0": + return _jsonrpc_error(request_id, -32600, "Invalid Request"), False + method = request.get("method") + if not isinstance(method, str): + return _jsonrpc_error(request_id, -32600, "Invalid Request"), False + params_obj = request.get("params", {}) + if not isinstance(params_obj, dict): + return _jsonrpc_error(request_id, -32602, "Invalid params"), False + params = {str(k): v for k, v in params_obj.items()} + + if method == "ping": + return _jsonrpc_success(request_id, {"ok": True}), False + if method == "shutdown": + return _jsonrpc_success(request_id, {"ok": True}), True + if method != "search": + return _jsonrpc_error(request_id, -32601, f"Method not found: {method}"), False + + response = search_func( + project_root=_required_str(params, "project_root"), + query=_required_str(params, "query"), + languages=_optional_str_list(params, "languages"), + paths=_optional_str_list(params, "paths"), + repo_keys=_optional_str_list(params, "repo_keys"), + limit=_positive_int_param(params, "limit", 10), + offset=_non_negative_int_param(params, "offset", 0), + ) + return _jsonrpc_success(request_id, search_response_json_payload(response)), False + except ValueError as e: + return _jsonrpc_error(request_id, -32602, str(e)), False + except RuntimeError as e: + return _jsonrpc_error(request_id, -32000, str(e)), False + + +def run_jsonrpc_bridge( + input_stream: TextIO = sys.stdin, + output_stream: TextIO = sys.stdout, + search_func: _SearchCallable | None = None, +) -> None: + """Run the JSON-RPC bridge over newline-delimited stdin/stdout.""" + if search_func is None: + from . import client as _client + + search_func = _client.search + + for line in input_stream: + stripped = line.strip() + if not stripped: + continue + try: + request = _json.loads(stripped) + except _json.JSONDecodeError: + response = _jsonrpc_error(None, -32700, "Parse error") + should_exit = False + else: + response, should_exit = handle_bridge_jsonrpc_request(request, search_func) + output_stream.write(_json.dumps(response, separators=(",", ":")) + "\n") + output_stream.flush() + if should_exit: + break + + def _run_index_with_progress(project_root: str) -> None: """Run indexing with streaming progress display. Exits on failure.""" from rich.console import Console as _Console @@ -231,6 +416,7 @@ def _search_with_wait_spinner( query: str, languages: list[str] | None = None, paths: list[str] | None = None, + repo_keys: list[str] | None = None, limit: int = 10, offset: int = 0, ) -> SearchResponse: @@ -256,6 +442,7 @@ def _on_waiting() -> None: query=query, languages=languages, paths=paths, + repo_keys=repo_keys, limit=limit, offset=offset, on_waiting=_on_waiting, @@ -549,9 +736,11 @@ def search( query: list[str] = _typer.Argument(..., help="Search query"), lang: list[str] = _typer.Option([], "--lang", help="Filter by language"), path: str | None = _typer.Option(None, "--path", help="Filter by file path glob"), + repo_key: list[str] = _typer.Option([], "--repo-key", help="Filter by indexed repo key"), offset: int = _typer.Option(0, "--offset", help="Number of results to skip"), limit: int = _typer.Option(10, "--limit", help="Maximum results to return"), refresh: bool = _typer.Option(False, "--refresh", help="Refresh index before searching"), + json_output: bool = _typer.Option(False, "--json", help="Print results as JSON"), ) -> None: """Semantic search across the codebase.""" project_root = str(require_project_root()) @@ -574,10 +763,29 @@ def search( query=query_str, languages=lang or None, paths=paths, + repo_keys=repo_key or None, limit=limit, offset=offset, ) - print_search_results(resp) + if json_output: + print_search_results_json(resp) + else: + print_search_results(resp) + + +@app.command() +def bridge( + jsonrpc: bool = _typer.Option( + False, + "--jsonrpc", + help="Run a JSON-RPC bridge over stdin/stdout", + ), +) -> None: + """Run a long-lived bridge for external tools.""" + if not jsonrpc: + _typer.echo("Error: pass --jsonrpc to select the bridge protocol.", err=True) + raise _typer.Exit(code=1) + run_jsonrpc_bridge() @app.command() diff --git a/src/cocoindex_code/client.py b/src/cocoindex_code/client.py index 262af87..5814b91 100644 --- a/src/cocoindex_code/client.py +++ b/src/cocoindex_code/client.py @@ -278,6 +278,7 @@ def search( query: str, languages: list[str] | None = None, paths: list[str] | None = None, + repo_keys: list[str] | None = None, limit: int = 5, offset: int = 0, on_waiting: Callable[[], None] | None = None, @@ -298,6 +299,7 @@ def search( query=query, languages=languages, paths=paths, + repo_keys=repo_keys, limit=limit, offset=offset, ) diff --git a/src/cocoindex_code/daemon.py b/src/cocoindex_code/daemon.py index 41334bc..ff88105 100644 --- a/src/cocoindex_code/daemon.py +++ b/src/cocoindex_code/daemon.py @@ -275,6 +275,7 @@ async def _search_with_wait( query=req.query, languages=req.languages, paths=req.paths, + repo_keys=req.repo_keys, limit=req.limit, offset=req.offset, ) @@ -488,6 +489,7 @@ async def _dispatch( query=req.query, languages=req.languages, paths=req.paths, + repo_keys=req.repo_keys, limit=req.limit, offset=req.offset, ) diff --git a/src/cocoindex_code/indexer.py b/src/cocoindex_code/indexer.py index e028103..ad2823e 100644 --- a/src/cocoindex_code/indexer.py +++ b/src/cocoindex_code/indexer.py @@ -33,6 +33,22 @@ splitter = RecursiveSplitter() +def repo_key_for_path(file_path: PurePath, project_root: Path) -> str: + """Return the relative Git repo root for fast scoped search.""" + directory = file_path.parent + while True: + if (project_root / directory / ".git").exists(): + repo_key = directory.as_posix() + return repo_key if repo_key != "." else "." + + if directory in (PurePath("."), PurePath("")): + break + directory = directory.parent + + parts = file_path.parts + return parts[0] if len(parts) > 1 else "." + + def _normalize_gitignore_lines(lines: Iterable[str], directory: PurePath) -> list[str]: """Normalize .gitignore lines to root-relative gitignore patterns.""" if directory in (PurePath("."), PurePath("")): @@ -151,8 +167,9 @@ async def process_file( if not content.strip(): return - suffix = file.file_path.path.suffix project_root = coco.use_context(CODEBASE_DIR) + suffix = file.file_path.path.suffix + repo_key = repo_key_for_path(file.file_path.path, project_root) ps = load_project_settings(project_root) ext_lang_map = {f".{lo.ext}": lo.lang for lo in ps.language_overrides} language = ( @@ -183,6 +200,7 @@ async def process(chunk: Chunk) -> None: row=CodeChunk( id=await id_gen.next_id(chunk.text), file_path=file.file_path.path.as_posix(), + repo_key=repo_key, language=language, content=chunk.text, start_line=chunk.start.line, @@ -209,7 +227,7 @@ async def indexer_main() -> None: primary_key=["id"], ), virtual_table_def=Vec0TableDef( - partition_key_columns=["language"], + partition_key_columns=["repo_key", "language"], auxiliary_columns=["file_path", "content", "start_line", "end_line"], ), ) diff --git a/src/cocoindex_code/project.py b/src/cocoindex_code/project.py index f661c21..ffb1aff 100644 --- a/src/cocoindex_code/project.py +++ b/src/cocoindex_code/project.py @@ -179,6 +179,7 @@ async def search( query: str, languages: list[str] | None = None, paths: list[str] | None = None, + repo_keys: list[str] | None = None, limit: int = 5, offset: int = 0, ) -> list[SearchResult]: @@ -192,10 +193,12 @@ async def search( offset=offset, languages=languages, paths=paths, + repo_keys=repo_keys, ) return [ SearchResult( file_path=r.file_path, + repo_key=r.repo_key, language=r.language, content=r.content, start_line=r.start_line, diff --git a/src/cocoindex_code/protocol.py b/src/cocoindex_code/protocol.py index b584a4d..024b343 100644 --- a/src/cocoindex_code/protocol.py +++ b/src/cocoindex_code/protocol.py @@ -22,6 +22,7 @@ class SearchRequest(_msgspec.Struct, tag="search"): query: str languages: list[str] | None = None paths: list[str] | None = None + repo_keys: list[str] | None = None limit: int = 5 offset: int = 0 @@ -111,6 +112,7 @@ class SearchResult(_msgspec.Struct): start_line: int end_line: int score: float + repo_key: str | None = None class SearchResponse(_msgspec.Struct, tag="search"): diff --git a/src/cocoindex_code/query.py b/src/cocoindex_code/query.py index a2991ee..d0aade8 100644 --- a/src/cocoindex_code/query.py +++ b/src/cocoindex_code/query.py @@ -21,26 +21,29 @@ def _knn_query( embedding_bytes: bytes, k: int, language: str | None = None, + repo_key: str | None = None, + has_repo_key: bool = False, ) -> list[tuple[Any, ...]]: """Run a vec0 KNN query, optionally constrained to a language partition.""" + conditions = ["embedding MATCH ?", "k = ?"] + params: list[Any] = [embedding_bytes, k] + if repo_key is not None: + conditions.append("repo_key = ?") + params.append(repo_key) if language is not None: - return conn.execute( - """ - SELECT file_path, language, content, start_line, end_line, distance - FROM code_chunks_vec - WHERE embedding MATCH ? AND k = ? AND language = ? - ORDER BY distance - """, - (embedding_bytes, k, language), - ).fetchall() + conditions.append("language = ?") + params.append(language) + + repo_key_select = "repo_key" if has_repo_key else "NULL" return conn.execute( - """ - SELECT file_path, language, content, start_line, end_line, distance + f""" + SELECT file_path, {repo_key_select} as repo_key, + language, content, start_line, end_line, distance FROM code_chunks_vec - WHERE embedding MATCH ? AND k = ? + WHERE {" AND ".join(conditions)} ORDER BY distance """, - (embedding_bytes, k), + params, ).fetchall() @@ -51,27 +54,42 @@ def _full_scan_query( offset: int, languages: list[str] | None = None, paths: list[str] | None = None, + repo_keys: list[str] | None = None, ) -> list[tuple[Any, ...]]: """Full scan with SQL-level distance computation and filtering.""" conditions: list[str] = [] params: list[Any] = [embedding_bytes] + has_repo_key = _table_has_column(conn, "code_chunks_vec", "repo_key") + if languages: placeholders = ",".join("?" for _ in languages) conditions.append(f"language IN ({placeholders})") params.extend(languages) + if repo_keys: + if has_repo_key: + placeholders = ",".join("?" for _ in repo_keys) + conditions.append(f"repo_key IN ({placeholders})") + params.extend(repo_keys) + else: + repo_key_paths = [ + f"{repo_key.rstrip('/')}/*" for repo_key in repo_keys if repo_key != "." + ] + paths = [*(paths or []), *repo_key_paths] or paths + if paths: path_clauses = " OR ".join("file_path GLOB ?" for _ in paths) conditions.append(f"({path_clauses})") params.extend(paths) + repo_key_select = "repo_key" if has_repo_key else "NULL as repo_key" where = f"WHERE {' AND '.join(conditions)}" if conditions else "" params.extend([limit, offset]) return conn.execute( f""" - SELECT file_path, language, content, start_line, end_line, + SELECT file_path, {repo_key_select}, language, content, start_line, end_line, vec_distance_L2(embedding, ?) as distance FROM code_chunks_vec {where} @@ -82,6 +100,22 @@ def _full_scan_query( ).fetchall() +def _table_has_column(conn: sqlite3.Connection, table_name: str, column_name: str) -> bool: + return any(row[1] == column_name for row in conn.execute(f"PRAGMA table_info({table_name})")) + + +def _repo_key_candidates(repo_keys: list[str] | None) -> list[str | None]: + if repo_keys: + return list(repo_keys) + return [None] + + +def _language_candidates(languages: list[str] | None) -> list[str | None]: + if languages: + return list(languages) + return [None] + + async def query_codebase( query: str, target_sqlite_db_path: Path, @@ -90,6 +124,7 @@ async def query_codebase( offset: int = 0, languages: list[str] | None = None, paths: list[str] | None = None, + repo_keys: list[str] | None = None, ) -> list[QueryResult]: """ Perform vector similarity search using vec0 KNN index. @@ -97,6 +132,8 @@ async def query_codebase( Uses sqlite-vec's vec0 virtual table for indexed nearest-neighbor search. Language filtering uses vec0 partition keys for exact index-level filtering. Path filtering triggers a full scan with distance computation. + Repo-key filtering uses the vec0 partition key when available, and + falls back to equivalent path filters for older indexes. """ if not target_sqlite_db_path.exists(): raise RuntimeError( @@ -114,34 +151,46 @@ async def query_codebase( embedding_bytes = query_embedding.astype("float32").tobytes() with db.readonly() as conn: + has_repo_key = _table_has_column(conn, "code_chunks_vec", "repo_key") if paths: - rows = _full_scan_query(conn, embedding_bytes, limit, offset, languages, paths) - elif not languages or len(languages) == 1: + rows = _full_scan_query( + conn, embedding_bytes, limit, offset, languages, paths, repo_keys + ) + elif repo_keys and not has_repo_key: + rows = _full_scan_query( + conn, embedding_bytes, limit, offset, languages, None, repo_keys + ) + elif (not languages or len(languages) == 1) and (not repo_keys or len(repo_keys) == 1): lang = languages[0] if languages else None - rows = _knn_query(conn, embedding_bytes, limit + offset, lang) + repo_key = repo_keys[0] if repo_keys else None + rows = _knn_query(conn, embedding_bytes, limit + offset, lang, repo_key, has_repo_key) else: fetch_k = limit + offset rows = heapq.nsmallest( fetch_k, ( row - for lang in languages - for row in _knn_query(conn, embedding_bytes, fetch_k, lang) + for repo_key in _repo_key_candidates(repo_keys) + for lang in _language_candidates(languages) + for row in _knn_query( + conn, embedding_bytes, fetch_k, lang, repo_key, has_repo_key + ) ), - key=lambda r: r[5], + key=lambda r: r[6], ) - if not paths: + if not paths and not (repo_keys and not has_repo_key): rows = rows[offset:] return [ QueryResult( file_path=file_path, + repo_key=repo_key, language=language, content=content, start_line=start_line, end_line=end_line, score=_l2_to_score(distance), ) - for file_path, language, content, start_line, end_line, distance in rows + for file_path, repo_key, language, content, start_line, end_line, distance in rows ] diff --git a/src/cocoindex_code/schema.py b/src/cocoindex_code/schema.py index bfb8a74..922bcd5 100644 --- a/src/cocoindex_code/schema.py +++ b/src/cocoindex_code/schema.py @@ -10,6 +10,7 @@ class CodeChunk: id: int file_path: str + repo_key: str language: str content: str start_line: int @@ -22,6 +23,7 @@ class QueryResult: """Result from a vector similarity query.""" file_path: str + repo_key: str | None language: str content: str start_line: int diff --git a/src/cocoindex_code/server.py b/src/cocoindex_code/server.py index 2708c86..4b62309 100644 --- a/src/cocoindex_code/server.py +++ b/src/cocoindex_code/server.py @@ -36,6 +36,7 @@ class CodeChunkResult(BaseModel): """A single code chunk result.""" file_path: str = Field(description="Relative path to the file") + repo_key: str | None = Field(default=None, description="Top-level indexed repo/workspace key") language: str = Field(description="Programming language") content: str = Field(description="The code content") start_line: int = Field(description="Starting line number (1-indexed)") @@ -117,6 +118,13 @@ async def search( " Example: ['src/utils/*', '*.py']" ), ), + repo_keys: list[str] | None = Field( + default=None, + description=( + "Filter by indexed top-level repo/workspace key(s). " + "This uses a vector index partition when available." + ), + ), ) -> SearchResultModel: """Query the codebase index via the daemon.""" from . import client as _client @@ -132,6 +140,7 @@ async def search( query=query, languages=languages, paths=paths, + repo_keys=repo_keys, limit=limit, offset=offset, ), @@ -141,6 +150,7 @@ async def search( results=[ CodeChunkResult( file_path=r.file_path, + repo_key=r.repo_key, language=r.language, content=r.content, start_line=r.start_line, diff --git a/src/cocoindex_code/shared.py b/src/cocoindex_code/shared.py index b42e722..170f5d8 100644 --- a/src/cocoindex_code/shared.py +++ b/src/cocoindex_code/shared.py @@ -138,6 +138,7 @@ class CodeChunk: id: int file_path: str + repo_key: str language: str content: str start_line: int diff --git a/tests/test_cli_helpers.py b/tests/test_cli_helpers.py index ec9876a..3598277 100644 --- a/tests/test_cli_helpers.py +++ b/tests/test_cli_helpers.py @@ -2,9 +2,13 @@ from __future__ import annotations +import json +import re +from io import StringIO from pathlib import Path import pytest +from typer.testing import CliRunner from cocoindex_code import cli from cocoindex_code.cli import ( @@ -13,6 +17,13 @@ require_project_root, resolve_default_path, ) +from cocoindex_code.protocol import SearchResponse, SearchResult + +_ANSI_RE = re.compile(r"\x1b\[[0-?]*[ -/]*[@-~]") + + +def _strip_ansi(text: str) -> str: + return _ANSI_RE.sub("", text) def test_require_project_root_success(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: @@ -84,6 +95,223 @@ def test_resolve_default_path_outside_project( assert result is None +def test_search_help_includes_json_option() -> None: + runner = CliRunner() + + result = runner.invoke(cli.app, ["search", "--help"], catch_exceptions=False) + + assert result.exit_code == 0 + output = _strip_ansi(result.output) + assert "--json" in output + assert "--repo-key" in output + + +def test_bridge_help_includes_jsonrpc_option() -> None: + runner = CliRunner() + + result = runner.invoke(cli.app, ["bridge", "--help"], catch_exceptions=False) + + assert result.exit_code == 0 + assert "--jsonrpc" in _strip_ansi(result.output) + + +def test_print_search_results_json_outputs_machine_readable_payload( + capsys: pytest.CaptureFixture[str], +) -> None: + response = SearchResponse( + success=True, + results=[ + SearchResult( + file_path="src/main.py", + language="python", + content="def main():\n return 1", + start_line=10, + end_line=11, + score=0.875, + ) + ], + total_returned=1, + offset=5, + message=None, + ) + + cli.print_search_results_json(response) + + assert json.loads(capsys.readouterr().out) == { + "success": True, + "results": [ + { + "file_path": "src/main.py", + "repo_key": None, + "language": "python", + "content": "def main():\n return 1", + "start_line": 10, + "end_line": 11, + "score": 0.875, + } + ], + "total_returned": 1, + "offset": 5, + "message": None, + } + + +def test_jsonrpc_bridge_ping_and_shutdown() -> None: + input_stream = StringIO( + '{"jsonrpc":"2.0","id":1,"method":"ping"}\n' + '{"jsonrpc":"2.0","id":2,"method":"shutdown"}\n' + '{"jsonrpc":"2.0","id":3,"method":"ping"}\n' + ) + output_stream = StringIO() + + def fake_search( + project_root: str, + query: str, + languages: list[str] | None = None, + paths: list[str] | None = None, + repo_keys: list[str] | None = None, + limit: int = 5, + offset: int = 0, + on_waiting: object | None = None, + ) -> SearchResponse: + raise AssertionError("search should not be called") + + cli.run_jsonrpc_bridge(input_stream, output_stream, fake_search) + + responses = [json.loads(line) for line in output_stream.getvalue().splitlines()] + assert responses == [ + {"jsonrpc": "2.0", "id": 1, "result": {"ok": True}}, + {"jsonrpc": "2.0", "id": 2, "result": {"ok": True}}, + ] + + +def test_jsonrpc_bridge_search_uses_client_payload() -> None: + input_stream = StringIO( + json.dumps( + { + "jsonrpc": "2.0", + "id": "search-1", + "method": "search", + "params": { + "project_root": "/workspace", + "query": "stream writer", + "languages": ["python"], + "paths": ["src/*"], + "repo_keys": ["repo-a"], + "limit": 3, + "offset": 2, + }, + } + ) + + "\n" + ) + output_stream = StringIO() + calls: list[dict[str, object]] = [] + + def fake_search( + project_root: str, + query: str, + languages: list[str] | None = None, + paths: list[str] | None = None, + repo_keys: list[str] | None = None, + limit: int = 5, + offset: int = 0, + on_waiting: object | None = None, + ) -> SearchResponse: + calls.append( + { + "project_root": project_root, + "query": query, + "languages": languages, + "paths": paths, + "repo_keys": repo_keys, + "limit": limit, + "offset": offset, + } + ) + return SearchResponse( + success=True, + results=[ + SearchResult( + file_path="src/main.py", + repo_key="repo-a", + language="python", + content="def stream_writer(): pass", + start_line=4, + end_line=4, + score=0.9, + ) + ], + total_returned=1, + offset=2, + message=None, + ) + + cli.run_jsonrpc_bridge(input_stream, output_stream, fake_search) + + assert calls == [ + { + "project_root": "/workspace", + "query": "stream writer", + "languages": ["python"], + "paths": ["src/*"], + "repo_keys": ["repo-a"], + "limit": 3, + "offset": 2, + } + ] + response = json.loads(output_stream.getvalue()) + assert response == { + "jsonrpc": "2.0", + "id": "search-1", + "result": { + "success": True, + "results": [ + { + "file_path": "src/main.py", + "repo_key": "repo-a", + "language": "python", + "content": "def stream_writer(): pass", + "start_line": 4, + "end_line": 4, + "score": 0.9, + } + ], + "total_returned": 1, + "offset": 2, + "message": None, + }, + } + + +def test_jsonrpc_bridge_returns_parse_error() -> None: + input_stream = StringIO("{not json}\n") + output_stream = StringIO() + + def fake_search( + project_root: str, + query: str, + languages: list[str] | None = None, + paths: list[str] | None = None, + repo_keys: list[str] | None = None, + limit: int = 5, + offset: int = 0, + on_waiting: object | None = None, + ) -> SearchResponse: + raise AssertionError("search should not be called") + + cli.run_jsonrpc_bridge(input_stream, output_stream, fake_search) + + assert json.loads(output_stream.getvalue()) == { + "jsonrpc": "2.0", + "id": None, + "error": { + "code": -32700, + "message": "Parse error", + }, + } + + # --------------------------------------------------------------------------- # .gitignore helpers # --------------------------------------------------------------------------- diff --git a/tests/test_indexer_helpers.py b/tests/test_indexer_helpers.py new file mode 100644 index 0000000..c845c83 --- /dev/null +++ b/tests/test_indexer_helpers.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from pathlib import Path, PurePath + +from cocoindex_code.indexer import repo_key_for_path + + +def test_repo_key_for_path_uses_nested_git_repo_root(tmp_path: Path) -> None: + repo = tmp_path / "ADK" / "a2a-samples" + (repo / ".git").mkdir(parents=True) + + assert repo_key_for_path(PurePath("ADK/a2a-samples/src/main.py"), tmp_path) == ( + "ADK/a2a-samples" + ) + + +def test_repo_key_for_path_uses_root_git_repo(tmp_path: Path) -> None: + (tmp_path / ".git").mkdir() + + assert repo_key_for_path(PurePath("src/main.py"), tmp_path) == "." + + +def test_repo_key_for_path_falls_back_to_top_level_component(tmp_path: Path) -> None: + assert repo_key_for_path(PurePath("workspace/src/main.py"), tmp_path) == "workspace" + assert repo_key_for_path(PurePath("README.md"), tmp_path) == "." diff --git a/tests/test_protocol.py b/tests/test_protocol.py index bf1d216..9e7aafc 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -50,6 +50,7 @@ def test_encode_decode_search_request_with_defaults() -> None: decoded = decode_request(data) assert isinstance(decoded, SearchRequest) assert decoded.languages is None + assert decoded.repo_keys is None assert decoded.limit == 5 assert decoded.offset == 0 @@ -60,6 +61,7 @@ def test_encode_decode_search_request_with_all_fields() -> None: query="hello world", languages=["python", "rust"], paths=["src/*"], + repo_keys=["repo-a"], limit=20, offset=5, ) @@ -70,6 +72,7 @@ def test_encode_decode_search_request_with_all_fields() -> None: assert decoded.query == "hello world" assert decoded.languages == ["python", "rust"] assert decoded.paths == ["src/*"] + assert decoded.repo_keys == ["repo-a"] assert decoded.limit == 20 assert decoded.offset == 5