-
Notifications
You must be signed in to change notification settings - Fork 111
Expand file tree
/
Copy pathshared.py
More file actions
88 lines (67 loc) · 2.8 KB
/
shared.py
File metadata and controls
88 lines (67 loc) · 2.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""Shared singletons: config, embedder, and CocoIndex lifecycle."""
from __future__ import annotations
import logging
from collections.abc import Iterator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Annotated
import cocoindex as coco
from cocoindex.connectors import sqlite
from cocoindex.connectors.localfs import FilePath, register_base_dir
from numpy.typing import NDArray
if TYPE_CHECKING:
    from cocoindex.ops.litellm import LiteLLMEmbedder
    from cocoindex.ops.sentence_transformers import SentenceTransformerEmbedder

from .config import config

logger = logging.getLogger(__name__)

# Prefix on `config.embedding_model` that selects a local sentence-transformers
# model; any other value is passed to LiteLLM unchanged.
SBERT_PREFIX = "sbert/"

# Sentence-transformers models that define a "query" prompt for asymmetric
# retrieval (queries are embedded with a different prompt than documents).
_QUERY_PROMPT_MODELS = frozenset({"nomic-ai/nomic-embed-code", "nomic-ai/CodeRankEmbed"})

# Module-level embedder, chosen once at import time from the model-name prefix.
# Both names are declared up front so each branch below is checked against
# the same declared types.
embedder: SentenceTransformerEmbedder | LiteLLMEmbedder
# Prompt name to pass when embedding search queries; None means "no prompt".
query_prompt_name: str | None

if config.embedding_model.startswith(SBERT_PREFIX):
    # Runtime import so sentence-transformers is only loaded when selected.
    from cocoindex.ops.sentence_transformers import SentenceTransformerEmbedder

    _model_name = config.embedding_model[len(SBERT_PREFIX) :]
    if not _model_name:
        # Fail fast with an actionable message instead of an opaque error
        # from the model loader.
        raise ValueError(
            f"Embedding model {config.embedding_model!r} has no model name "
            f"after the {SBERT_PREFIX!r} prefix"
        )
    query_prompt_name = "query" if _model_name in _QUERY_PROMPT_MODELS else None
    # NOTE(review): trust_remote_code=True lets the model repository execute its
    # own Python code when loaded. Presumably it is needed for models with custom
    # architectures (e.g. the nomic ones above) — TODO confirm. It applies to
    # *any* model named in config, so only point this at trusted repositories.
    embedder = SentenceTransformerEmbedder(
        _model_name,
        device=config.device,
        trust_remote_code=True,
    )
    logger.info(
        "Embedding model: %s | device: %s",
        config.embedding_model,
        config.device,
    )
else:
    # Runtime import so LiteLLM is only loaded when selected.
    from cocoindex.ops.litellm import LiteLLMEmbedder

    embedder = LiteLLMEmbedder(config.embedding_model)
    query_prompt_name = None
    logger.info("Embedding model (LiteLLM): %s", config.embedding_model)
# Context keys whose values are supplied by `coco_lifespan`:
#   SQLITE_DB    -> the registered target SQLite database; the underlying
#                   connection is opened and closed by the lifespan.
#   CODEBASE_DIR -> the registered root directory of the codebase.
SQLITE_DB = coco.ContextKey[sqlite.SqliteDatabase]("sqlite_db")
CODEBASE_DIR = coco.ContextKey[FilePath]("codebase_dir")
@coco.lifespan
def coco_lifespan(builder: coco.EnvironmentBuilder) -> Iterator[None]:
    """Configure the CocoIndex environment and own the index SQLite connection.

    Before yielding, this:

    - creates ``config.index_dir`` (with parents) if it does not exist,
    - sets CocoIndex's state database path to ``config.cocoindex_db_path``,
    - provides ``CODEBASE_DIR`` as the base dir registered as ``"codebase"``
      for ``config.codebase_root_path``,
    - opens the target SQLite database with the vector extension
      (``load_vec="auto"``) and provides it as ``SQLITE_DB`` (``"index_db"``).

    The connection is closed when the lifespan ends, whether it ends normally
    or with an error.

    Args:
        builder: Environment builder to configure and provide context values on.
    """
    config.index_dir.mkdir(parents=True, exist_ok=True)
    builder.settings.db_path = config.cocoindex_db_path
    builder.provide(CODEBASE_DIR, register_base_dir("codebase", config.codebase_root_path))

    conn = sqlite.connect(str(config.target_sqlite_db_path), load_vec="auto")
    try:
        builder.provide(SQLITE_DB, sqlite.register_db("index_db", conn))
        yield
    finally:
        # A bare `conn.close()` after `yield` is skipped if the lifespan driver
        # throws an exception in at `yield` or closes the generator, and also if
        # registration above raises; `finally` closes it on every path.
        conn.close()
@dataclass
class CodeChunk:
    """Schema for storing code chunks in SQLite: one row per embedded chunk.

    NOTE(review): the field order and annotations presumably define the row
    schema derived by CocoIndex's sqlite connector — confirm. If so, do not
    reorder, rename, or retype fields without migrating the existing index.
    """

    # Row identifier for the chunk.
    id: int
    # Path of the file the chunk came from. Presumably relative to the
    # registered codebase root (CODEBASE_DIR) — TODO confirm against the producer.
    file_path: str
    # Language label of the chunk's source file.
    language: str
    # Text of the chunk.
    content: str
    # Line span of the chunk within its file. NOTE(review): whether lines are
    # 0- or 1-based and whether end_line is inclusive is not visible here —
    # confirm against the chunker.
    start_line: int
    end_line: int
    # Embedding vector for the chunk. The module-level `embedder` is attached as
    # `Annotated` metadata, presumably so CocoIndex can derive the vector column
    # (e.g. its dimension) from the configured model — confirm.
    embedding: Annotated[NDArray, embedder]  # type: ignore[type-arg]